From e2c4ea3b0ee8b1027236c4acbcb5de9064c7d26f Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Sat, 14 Dec 2019 21:38:19 +0800 Subject: [PATCH] Collect env info for trouble shoting (#1812) * collect env info for trouble shoting * minor fix * update the issue template * fix the travis building * update setup classifiers --- .github/ISSUE_TEMPLATE/error-report.md | 12 ++--- .isort.cfg | 2 +- .travis.yml | 2 +- mmdet/ops/__init__.py | 4 +- mmdet/ops/utils/__init__.py | 7 +++ mmdet/ops/utils/src/compiling_info.cpp | 56 ++++++++++++++++++++++ setup.py | 17 +++++-- tools/collect_env.py | 64 ++++++++++++++++++++++++++ 8 files changed, 151 insertions(+), 13 deletions(-) create mode 100644 mmdet/ops/utils/__init__.py create mode 100644 mmdet/ops/utils/src/compiling_info.cpp create mode 100644 tools/collect_env.py diff --git a/.github/ISSUE_TEMPLATE/error-report.md b/.github/ISSUE_TEMPLATE/error-report.md index a5020bc..80e1cc5 100644 --- a/.github/ISSUE_TEMPLATE/error-report.md +++ b/.github/ISSUE_TEMPLATE/error-report.md @@ -25,13 +25,11 @@ A placeholder for the command. 3. What dataset did you use? **Environment** - - OS: [e.g., Ubuntu 16.04.6] - - GCC [e.g., 5.4.0] - - PyTorch version [e.g., 1.1.0] -- How you installed PyTorch [e.g., pip, conda, source] -- GPU model [e.g., 1080Ti, V100] -- CUDA and CUDNN version -- [optional] Other information that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.) + +1. Please run `python tools/collect_env.py` to collect necessary environment infomation and paste it here. +2. You may add addition that may be helpful for locating the problem, such as + - How you installed PyTorch [e.g., pip, conda, source] + - Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.) **Error traceback** If applicable, paste the error trackback here. diff --git a/.isort.cfg b/.isort.cfg index 8c2eb09..2186a18 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -3,6 +3,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = setuptools known_first_party = mmdet -known_third_party = Cython,albumentations,cv2,imagecorruptions,matplotlib,mmcv,numpy,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch +known_third_party = Cython,albumentations,cv2,imagecorruptions,matplotlib,mmcv,numpy,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/.travis.yml b/.travis.yml index a813c5e..d51fc0d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,7 +6,7 @@ python: - "3.6" - "3.7" -env: CUDA=10.1.105-1 CUDA_SHORT=10.1 UBUNTU_VERSION=ubuntu1804 +env: CUDA=10.1.105-1 CUDA_SHORT=10.1 UBUNTU_VERSION=ubuntu1804 FORCE_CUDA=1 cache: pip # Ref to CUDA installation in Travis: https://github.com/jeremad/cuda-travis diff --git a/mmdet/ops/__init__.py b/mmdet/ops/__init__.py index 4317899..5c6a1f3 100644 --- a/mmdet/ops/__init__.py +++ b/mmdet/ops/__init__.py @@ -8,6 +8,7 @@ from .nms import nms, soft_nms from .roi_align import RoIAlign, roi_align from .roi_pool import RoIPool, roi_pool from .sigmoid_focal_loss import SigmoidFocalLoss, sigmoid_focal_loss +from .utils import get_compiler_version, get_compiling_cuda_version __all__ = [ 'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', @@ -15,5 +16,6 @@ __all__ = [ 'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv', 'deform_roi_pooling', 'SigmoidFocalLoss', 'sigmoid_focal_loss', - 'MaskedConv2d', 'ContextBlock' + 'MaskedConv2d', 'ContextBlock', 'get_compiler_version', + 'get_compiling_cuda_version' ] diff --git a/mmdet/ops/utils/__init__.py b/mmdet/ops/utils/__init__.py new file mode 100644 index 0000000..0244c0f --- /dev/null +++ b/mmdet/ops/utils/__init__.py @@ -0,0 +1,7 @@ +# from . import compiling_info +from .compiling_info import get_compiler_version, get_compiling_cuda_version + +# get_compiler_version = compiling_info.get_compiler_version +# get_compiling_cuda_version = compiling_info.get_compiling_cuda_version + +__all__ = ['get_compiler_version', 'get_compiling_cuda_version'] diff --git a/mmdet/ops/utils/src/compiling_info.cpp b/mmdet/ops/utils/src/compiling_info.cpp new file mode 100644 index 0000000..fd62aab --- /dev/null +++ b/mmdet/ops/utils/src/compiling_info.cpp @@ -0,0 +1,56 @@ +// modified from +// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/vision.cpp +#include <cuda_runtime_api.h> +#include <torch/extension.h> + +#ifdef WITH_CUDA +int get_cudart_version() { return CUDART_VERSION; } +#endif + +std::string get_compiling_cuda_version() { +#ifdef WITH_CUDA + std::ostringstream oss; + + // copied from + // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231 + auto printCudaStyleVersion = [&](int v) { + oss << (v / 1000) << "." << (v / 10 % 100); + if (v % 10 != 0) { + oss << "." << (v % 10); + } + }; + printCudaStyleVersion(get_cudart_version()); + return oss.str(); +#else + return std::string("not available"); +#endif +} + +// similar to +// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp +std::string get_compiler_version() { + std::ostringstream ss; +#if defined(__GNUC__) +#ifndef __clang__ + { ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; } +#endif +#endif + +#if defined(__clang_major__) + { + ss << "clang " << __clang_major__ << "." << __clang_minor__ << "." + << __clang_patchlevel__; + } +#endif + +#if defined(_MSC_VER) + { ss << "MSVC " << _MSC_FULL_VER; } +#endif + return ss.str(); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("get_compiler_version", &get_compiler_version, "get_compiler_version"); + m.def("get_compiling_cuda_version", &get_compiling_cuda_version, + "get_compiling_cuda_version"); +} diff --git a/setup.py b/setup.py index 166181f..43a2ec2 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,7 @@ import subprocess import time from setuptools import Extension, dist, find_packages, setup +import torch from torch.utils.cpp_extension import BuildExtension, CUDAExtension dist.Distribution().fetch_build_eggs(['Cython', 'numpy>=1.11.1']) @@ -92,9 +93,17 @@ def get_version(): def make_cuda_ext(name, module, sources): + define_macros = [] + + if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': + define_macros += [("WITH_CUDA", None)] + else: + raise EnvironmentError('CUDA is required to compile MMDetection!') + return CUDAExtension( name='{}.{}'.format(module, name), sources=[os.path.join(*module.split('.'), p) for p in sources], + define_macros=define_macros, extra_compile_args={ 'cxx': [], 'nvcc': [ @@ -146,18 +155,20 @@ if __name__ == '__main__': 'Development Status :: 4 - Beta', 'License :: OSI Approved :: Apache Software License', 'Operating System :: OS Independent', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', ], license='Apache License 2.0', setup_requires=['pytest-runner', 'cython', 'numpy'], tests_require=['pytest', 'xdoctest'], install_requires=get_requirements(), ext_modules=[ + make_cuda_ext( + name='compiling_info', + module='mmdet.ops.utils', + sources=['src/compiling_info.cpp']), make_cython_ext( name='soft_nms_cpu', module='mmdet.ops.nms', diff --git a/tools/collect_env.py b/tools/collect_env.py new file mode 100644 index 0000000..81d6c7a --- /dev/null +++ b/tools/collect_env.py @@ -0,0 +1,64 @@ +import os.path as osp +import subprocess +import sys +from collections import defaultdict + +import cv2 +import mmcv +import torch +import torchvision + +import mmdet +from mmdet.ops import get_compiler_version, get_compiling_cuda_version + + +def collect_env(): + env_info = {} + env_info['sys.platform'] = sys.platform + env_info['Python'] = sys.version.replace('\n', '') + + cuda_available = torch.cuda.is_available() + env_info['CUDA available'] = cuda_available + + if cuda_available: + from torch.utils.cpp_extension import CUDA_HOME + env_info['CUDA_HOME'] = CUDA_HOME + + if CUDA_HOME is not None and osp.isdir(CUDA_HOME): + try: + nvcc = osp.join(CUDA_HOME, 'bin/nvcc') + nvcc = subprocess.check_output( + '"{}" -V | tail -n1'.format(nvcc), shell=True) + nvcc = nvcc.decode('utf-8').strip() + except subprocess.SubprocessError: + nvcc = 'Not Available' + env_info['NVCC'] = nvcc + + devices = defaultdict(list) + for k in range(torch.cuda.device_count()): + devices[torch.cuda.get_device_name(k)].append(str(k)) + for name, devids in devices.items(): + env_info['GPU ' + ','.join(devids)] = name + + gcc = subprocess.check_output('gcc --version | head -n1', shell=True) + gcc = gcc.decode('utf-8').strip() + env_info['GCC'] = gcc + + env_info['PyTorch'] = torch.__version__ + env_info['PyTorch compiling details'] = torch.__config__.show() + + env_info['TorchVision'] = torchvision.__version__ + + env_info['OpenCV'] = cv2.__version__ + + env_info['MMCV'] = mmcv.__version__ + env_info['MMDetection'] = mmdet.__version__ + env_info['MMDetection Compiler'] = get_compiler_version() + env_info['MMDetection CUDA Compiler'] = get_compiling_cuda_version() + + for name, val in env_info.items(): + print('{}: {}'.format(name, val)) + + +if __name__ == "__main__": + collect_env() -- GitLab