Skip to content
Snippets Groups Projects
Unverified Commit c21ff089 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #35 from hellock/master

Bug fix for compiling nms op
parents 0bd74f77 f3768bcd
No related branches found
No related tags found
No related merge requests found
......@@ -5,4 +5,4 @@ all:
$(PYTHON) setup.py build_ext --inplace
clean:
rm *.so
rm -f *.so
import os
from distutils.core import setup
from distutils.extension import Extension
import os.path as osp
from distutils.core import setup, Extension
import numpy as np
from Cython.Build import cythonize
from Cython.Distutils import build_ext
CUDA_ROOT = '/usr/local/cuda'
CUDA = {
"include": os.path.join(CUDA_ROOT, 'include'),
"lib": os.path.join(CUDA_ROOT, 'lib64'),
"nvcc": os.path.join(CUDA_ROOT, 'bin', "nvcc")
}
inc_dirs = [CUDA['include'], np.get_include()]
lib_dirs = [CUDA['lib']]
# extensions
ext_args = dict(
include_dirs=inc_dirs,
library_dirs=lib_dirs,
include_dirs=[np.get_include()],
language='c++',
libraries=['cudart'],
extra_compile_args={
"cc": ['-Wno-unused-function', '-Wno-write-strings'],
"nvcc": [
'-arch=sm_52', '--ptxas-options=-v', '-c', '--compiler-options',
'-fPIC'
],
'cc': ['-Wno-unused-function', '-Wno-write-strings'],
'nvcc': ['-c', '--compiler-options', '-fPIC'],
},
)
extensions = [
Extension('cpu_nms', ['cpu_nms.pyx'], **ext_args),
Extension('gpu_nms', ['gpu_nms.pyx', 'nms_kernel.cu'], **ext_args),
Extension('cpu_soft_nms', ['cpu_soft_nms.pyx'], **ext_args),
Extension('gpu_nms', ['gpu_nms.pyx', 'nms_kernel.cu'], **ext_args),
]
......@@ -59,9 +42,9 @@ def customize_compiler_for_nvcc(self):
# object but distutils doesn't have the ability to change compilers
# based on source extension: we add it.
def _compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
if os.path.splitext(src)[1] == '.cu':
if osp.splitext(src)[1] == '.cu':
# use the cuda for .cu files
self.set_executable('compiler_so', CUDA['nvcc'])
self.set_executable('compiler_so', 'nvcc')
# use only a subset of the extra_postargs, which are 1-1 translated
# from the extra_compile_args in the Extension class
postargs = extra_postargs['nvcc']
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment