From e421e832883241bd7831bf77dc31d5fb31d7da58 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Sat, 6 Apr 2019 18:18:34 -0700
Subject: [PATCH] port nms extension from maskrcnn-benchmark (#469)

* port nms extension from maskrcnn-benchmark

* fix linting error
---
 .gitignore                                    |   2 +-
 compile.sh                                    |   6 +-
 mmdet/ops/nms/.gitignore                      |   1 -
 mmdet/ops/nms/Makefile                        |   8 -
 mmdet/ops/nms/cpu_nms.pyx                     |  70 -------
 mmdet/ops/nms/gpu_nms.hpp                     |   3 -
 mmdet/ops/nms/gpu_nms.pyx                     |  45 -----
 mmdet/ops/nms/nms_kernel.cu                   | 188 ------------------
 mmdet/ops/nms/nms_wrapper.py                  |  53 +++--
 mmdet/ops/nms/setup.py                        |  24 ++-
 mmdet/ops/nms/src/nms_cpu.cpp                 |  71 +++++++
 mmdet/ops/nms/src/nms_cuda.cpp                |  17 ++
 mmdet/ops/nms/src/nms_kernel.cu               | 131 ++++++++++++
 .../soft_nms_cpu.pyx}                         |   2 +-
 14 files changed, 276 insertions(+), 345 deletions(-)
 delete mode 100644 mmdet/ops/nms/.gitignore
 delete mode 100644 mmdet/ops/nms/Makefile
 delete mode 100644 mmdet/ops/nms/cpu_nms.pyx
 delete mode 100644 mmdet/ops/nms/gpu_nms.hpp
 delete mode 100644 mmdet/ops/nms/gpu_nms.pyx
 delete mode 100644 mmdet/ops/nms/nms_kernel.cu
 create mode 100644 mmdet/ops/nms/src/nms_cpu.cpp
 create mode 100644 mmdet/ops/nms/src/nms_cuda.cpp
 create mode 100644 mmdet/ops/nms/src/nms_kernel.cu
 rename mmdet/ops/nms/{cpu_soft_nms.pyx => src/soft_nms_cpu.pyx} (99%)

diff --git a/.gitignore b/.gitignore
index f189e1d..e7a290e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -104,7 +104,7 @@ venv.bak/
 .mypy_cache/
 
 # cython generated cpp
-mmdet/ops/nms/*.cpp
+mmdet/ops/nms/src/soft_nms_cpu.cpp
 mmdet/version.py
 data
 .vscode
diff --git a/compile.sh b/compile.sh
index 776de1f..9ae7d04 100755
--- a/compile.sh
+++ b/compile.sh
@@ -18,8 +18,10 @@ $PYTHON setup.py build_ext --inplace
 
 echo "Building nms op..."
 cd ../nms
-make clean
-make PYTHON=${PYTHON}
+if [ -d "build" ]; then
+    rm -r build
+fi
+$PYTHON setup.py build_ext --inplace
 
 echo "Building dcn..."
 cd ../dcn
diff --git a/mmdet/ops/nms/.gitignore b/mmdet/ops/nms/.gitignore
deleted file mode 100644
index ce1da4c..0000000
--- a/mmdet/ops/nms/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-*.cpp
diff --git a/mmdet/ops/nms/Makefile b/mmdet/ops/nms/Makefile
deleted file mode 100644
index af511f3..0000000
--- a/mmdet/ops/nms/Makefile
+++ /dev/null
@@ -1,8 +0,0 @@
-PYTHON=${PYTHON:-python}
-
-all:
-	echo "Compiling nms kernels..."
-	$(PYTHON) setup.py build_ext --inplace
-
-clean:
-	rm -f *.so
diff --git a/mmdet/ops/nms/cpu_nms.pyx b/mmdet/ops/nms/cpu_nms.pyx
deleted file mode 100644
index cccc58a..0000000
--- a/mmdet/ops/nms/cpu_nms.pyx
+++ /dev/null
@@ -1,70 +0,0 @@
-# --------------------------------------------------------
-# Fast R-CNN
-# Copyright (c) 2015 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-# Written by Ross Girshick
-# --------------------------------------------------------
-
-# cython: language_level=3, boundscheck=False
-
-import numpy as np
-cimport numpy as np
-
-cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
-    return a if a >= b else b
-
-cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
-    return a if a <= b else b
-
-def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh):
-    cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
-    cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
-    cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
-    cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]
-    cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4]
-
-    cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1)
-    cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1]
-
-    cdef int ndets = dets.shape[0]
-    cdef np.ndarray[np.int_t, ndim=1] suppressed = \
-            np.zeros((ndets), dtype=np.int)
-
-    # nominal indices
-    cdef int _i, _j
-    # sorted indices
-    cdef int i, j
-    # temp variables for box i's (the box currently under consideration)
-    cdef np.float32_t ix1, iy1, ix2, iy2, iarea
-    # variables for computing overlap with box j (lower scoring box)
-    cdef np.float32_t xx1, yy1, xx2, yy2
-    cdef np.float32_t w, h
-    cdef np.float32_t inter, ovr
-
-    keep = []
-    for _i in range(ndets):
-        i = order[_i]
-        if suppressed[i] == 1:
-            continue
-        keep.append(i)
-        ix1 = x1[i]
-        iy1 = y1[i]
-        ix2 = x2[i]
-        iy2 = y2[i]
-        iarea = areas[i]
-        for _j in range(_i + 1, ndets):
-            j = order[_j]
-            if suppressed[j] == 1:
-                continue
-            xx1 = max(ix1, x1[j])
-            yy1 = max(iy1, y1[j])
-            xx2 = min(ix2, x2[j])
-            yy2 = min(iy2, y2[j])
-            w = max(0.0, xx2 - xx1 + 1)
-            h = max(0.0, yy2 - yy1 + 1)
-            inter = w * h
-            ovr = inter / (iarea + areas[j] - inter)
-            if ovr >= thresh:
-                suppressed[j] = 1
-
-    return keep
diff --git a/mmdet/ops/nms/gpu_nms.hpp b/mmdet/ops/nms/gpu_nms.hpp
deleted file mode 100644
index 2d45e34..0000000
--- a/mmdet/ops/nms/gpu_nms.hpp
+++ /dev/null
@@ -1,3 +0,0 @@
-void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
-          int boxes_dim, float nms_overlap_thresh, int device_id, size_t base);
-size_t nms_Malloc();
diff --git a/mmdet/ops/nms/gpu_nms.pyx b/mmdet/ops/nms/gpu_nms.pyx
deleted file mode 100644
index f2e7857..0000000
--- a/mmdet/ops/nms/gpu_nms.pyx
+++ /dev/null
@@ -1,45 +0,0 @@
-# --------------------------------------------------------
-# Faster R-CNN
-# Copyright (c) 2015 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-# Written by Ross Girshick
-# --------------------------------------------------------
-
-# cython: language_level=3, boundscheck=False
-
-import numpy as np
-cimport numpy as np
-
-assert sizeof(int) == sizeof(np.int32_t)
-
-cdef extern from "gpu_nms.hpp":
-    void _nms(np.int32_t*, int*, np.float32_t*, int, int, float, int, size_t) nogil
-    size_t nms_Malloc() nogil
-
-memory_pool = {}
-
-def gpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh,
-            np.int32_t device_id=0):
-    cdef int boxes_num = dets.shape[0]
-    cdef int boxes_dim = 5
-    cdef int num_out
-    cdef size_t base
-    cdef np.ndarray[np.int32_t, ndim=1] \
-        keep = np.zeros(boxes_num, dtype=np.int32)
-    cdef np.ndarray[np.float32_t, ndim=1] \
-        scores = dets[:, 4]
-    cdef np.ndarray[np.int_t, ndim=1] \
-        order = scores.argsort()[::-1]
-    cdef np.ndarray[np.float32_t, ndim=2] \
-        sorted_dets = dets[order, :5]
-    cdef float cthresh = thresh
-    if device_id not in memory_pool:
-        with nogil:
-            base = nms_Malloc()
-        memory_pool[device_id] = base
-        # print "malloc", base
-    base = memory_pool[device_id]
-    with nogil:
-        _nms(&keep[0], &num_out, &sorted_dets[0, 0], boxes_num, boxes_dim, cthresh, device_id, base)
-    keep = keep[:num_out]
-    return list(order[keep])
diff --git a/mmdet/ops/nms/nms_kernel.cu b/mmdet/ops/nms/nms_kernel.cu
deleted file mode 100644
index 4c5f0ec..0000000
--- a/mmdet/ops/nms/nms_kernel.cu
+++ /dev/null
@@ -1,188 +0,0 @@
-// ------------------------------------------------------------------
-// Faster R-CNN
-// Copyright (c) 2015 Microsoft
-// Licensed under The MIT License [see fast-rcnn/LICENSE for details]
-// Written by Shaoqing Ren
-// ------------------------------------------------------------------
-
-#include <stdio.h>
-#include <iostream>
-#include <vector>
-#include "gpu_nms.hpp"
-
-#define CUDA_CHECK(condition)                                    \
-    /* Code block avoids redefinition of cudaError_t error */    \
-    do {                                                         \
-        cudaError_t error = condition;                           \
-        if (error != cudaSuccess) {                              \
-            std::cout << cudaGetErrorString(error) << std::endl; \
-        }                                                        \
-    } while (0)
-
-#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
-#define MULTIPLIER 16
-#define LONGLONG_SIZE 64
-
-int const threadsPerBlock =
-    sizeof(unsigned long long) * 8 *
-    MULTIPLIER;  // number of bits for a long long variable
-
-__device__ inline float devIoU(float const* const a, float const* const b) {
-    float left = max(a[0], b[0]), right = min(a[2], b[2]);
-    float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
-    float width = max(right - left + 1, 0.f),
-          height = max(bottom - top + 1, 0.f);
-    float interS = width * height;
-    float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
-    float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
-    return interS / (Sa + Sb - interS);
-}
-
-__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
-                           const float* dev_boxes,
-                           unsigned long long* dev_mask) {
-    const int row_start = blockIdx.y;
-    const int col_start = blockIdx.x;
-
-    // if (row_start > col_start) return;
-
-    const int row_size =
-        min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
-    const int col_size =
-        min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
-
-    __shared__ float block_boxes[threadsPerBlock * 5];
-    if (threadIdx.x < col_size) {
-        block_boxes[threadIdx.x * 5 + 0] =
-            dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
-        block_boxes[threadIdx.x * 5 + 1] =
-            dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
-        block_boxes[threadIdx.x * 5 + 2] =
-            dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
-        block_boxes[threadIdx.x * 5 + 3] =
-            dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
-        block_boxes[threadIdx.x * 5 + 4] =
-            dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
-    }
-    __syncthreads();
-
-    unsigned long long ts[MULTIPLIER];
-
-    if (threadIdx.x < row_size) {
-#pragma unroll
-        for (int i = 0; i < MULTIPLIER; ++i) {
-            ts[i] = 0;
-        }
-        const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
-        const float* cur_box = dev_boxes + cur_box_idx * 5;
-        int i = 0;
-        int start = 0;
-        if (row_start == col_start) {
-            start = threadIdx.x + 1;
-        }
-        for (i = start; i < col_size; i++) {
-            if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
-                ts[i / LONGLONG_SIZE] |= 1ULL << (i % LONGLONG_SIZE);
-            }
-        }
-        const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
-
-#pragma unroll
-        for (int i = 0; i < MULTIPLIER; ++i) {
-            dev_mask[(cur_box_idx * col_blocks + col_start) * MULTIPLIER + i] =
-                ts[i];
-        }
-    }
-}
-
-void _set_device(int device_id) {
-    int current_device;
-    CUDA_CHECK(cudaGetDevice(&current_device));
-    if (current_device == device_id) {
-        return;
-    }
-    // The call to cudaSetDevice must come before any calls to Get, which
-    // may perform initialization using the GPU.
-    CUDA_CHECK(cudaSetDevice(device_id));
-}
-
-const size_t MEMORY_SIZE = 500000000;
-size_t nms_Malloc() {
-    float* boxes_dev = NULL;
-    CUDA_CHECK(cudaMalloc(&boxes_dev, MEMORY_SIZE));
-    return size_t(boxes_dev);
-}
-
-void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
-          int boxes_dim, float nms_overlap_thresh, int device_id, size_t base) {
-    _set_device(device_id);
-
-    float* boxes_dev = NULL;
-    unsigned long long* mask_dev = NULL;
-
-    const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
-
-    if (base > 0) {
-        size_t require_mem =
-            boxes_num * boxes_dim * sizeof(float) +
-            boxes_num * col_blocks * sizeof(unsigned long long) * MULTIPLIER;
-        if (require_mem >= MEMORY_SIZE) {
-            std::cout << "require_mem: " << require_mem << std::endl;
-        }
-        boxes_dev = (float*)(base);
-        mask_dev =
-            (unsigned long long*)(base +
-                                  512 * ((unsigned long long)(boxes_num *
-                                                              boxes_dim *
-                                                              sizeof(float) /
-                                                              512) +
-                                         1));
-    } else {
-        CUDA_CHECK(
-            cudaMalloc(&boxes_dev, boxes_num * boxes_dim * sizeof(float)));
-        CUDA_CHECK(cudaMalloc(&mask_dev, MULTIPLIER * boxes_num * col_blocks *
-                                             sizeof(unsigned long long)));
-    }
-    CUDA_CHECK(cudaMemcpy(boxes_dev, boxes_host,
-                          boxes_num * boxes_dim * sizeof(float),
-                          cudaMemcpyHostToDevice));
-
-    dim3 blocks(DIVUP(boxes_num, threadsPerBlock),
-                DIVUP(boxes_num, threadsPerBlock));
-    dim3 threads(threadsPerBlock);
-    nms_kernel<<<blocks, threads>>>(boxes_num, nms_overlap_thresh, boxes_dev,
-                                    mask_dev);
-
-    std::vector<unsigned long long> mask_host(boxes_num * col_blocks *
-                                              MULTIPLIER);
-    CUDA_CHECK(cudaMemcpy(
-        &mask_host[0], mask_dev,
-        sizeof(unsigned long long) * boxes_num * col_blocks * MULTIPLIER,
-        cudaMemcpyDeviceToHost));
-
-    std::vector<unsigned long long> remv(col_blocks * MULTIPLIER);
-    memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks * MULTIPLIER);
-
-    int num_to_keep = 0;
-    for (int i = 0; i < boxes_num; i++) {
-        int nblock = i / threadsPerBlock;
-        int inblock = i % threadsPerBlock;
-        int offset = inblock / LONGLONG_SIZE;
-        int bit_pos = inblock % LONGLONG_SIZE;
-
-        if (!(remv[nblock * MULTIPLIER + offset] & (1ULL << bit_pos))) {
-            keep_out[num_to_keep++] = i;
-            unsigned long long* p = &mask_host[0] + i * col_blocks * MULTIPLIER;
-            for (int j = nblock * MULTIPLIER + offset;
-                 j < col_blocks * MULTIPLIER; j++) {
-                remv[j] |= p[j];
-            }
-        }
-    }
-    *num_out = num_to_keep;
-
-    if (!base) {
-        CUDA_CHECK(cudaFree(boxes_dev));
-        CUDA_CHECK(cudaFree(mask_dev));
-    }
-}
diff --git a/mmdet/ops/nms/nms_wrapper.py b/mmdet/ops/nms/nms_wrapper.py
index 83b2858..8ce5bc4 100644
--- a/mmdet/ops/nms/nms_wrapper.py
+++ b/mmdet/ops/nms/nms_wrapper.py
@@ -1,36 +1,51 @@
 import numpy as np
 import torch
 
-from .gpu_nms import gpu_nms
-from .cpu_nms import cpu_nms
-from .cpu_soft_nms import cpu_soft_nms
+from . import nms_cuda, nms_cpu
+from .soft_nms_cpu import soft_nms_cpu
 
 
 def nms(dets, iou_thr, device_id=None):
-    """Dispatch to either CPU or GPU NMS implementations."""
+    """Dispatch to either CPU or GPU NMS implementations.
+
+    The input can be either a torch tensor or numpy array. GPU NMS will be used
+    if the input is a gpu tensor or device_id is specified, otherwise CPU NMS
+    will be used. The returned type will always be the same as inputs.
+
+    Arguments:
+        dets (torch.Tensor or np.ndarray): bboxes with scores.
+        iou_thr (float): IoU threshold for NMS.
+        device_id (int, optional): when `dets` is a numpy array, if `device_id`
+            is None, then cpu nms is used, otherwise gpu_nms will be used.
+
+    Returns:
+        tuple: kept bboxes and indice, which is always the same data type as
+            the input.
+    """
+    # convert dets (tensor or numpy array) to tensor
     if isinstance(dets, torch.Tensor):
-        is_tensor = True
-        if dets.is_cuda:
-            device_id = dets.get_device()
-        dets_np = dets.detach().cpu().numpy()
+        is_numpy = False
+        dets_th = dets
     elif isinstance(dets, np.ndarray):
-        is_tensor = False
-        dets_np = dets
+        is_numpy = True
+        device = 'cpu' if device_id is None else 'cuda:{}'.format(device_id)
+        dets_th = torch.from_numpy(dets).to(device)
     else:
         raise TypeError(
             'dets must be either a Tensor or numpy array, but got {}'.format(
                 type(dets)))
 
-    if dets_np.shape[0] == 0:
-        inds = []
+    # execute cpu or cuda nms
+    if dets_th.shape[0] == 0:
+        inds = dets_th.new_zeros(0, dtype=torch.long)
     else:
-        inds = (gpu_nms(dets_np, iou_thr, device_id=device_id)
-                if device_id is not None else cpu_nms(dets_np, iou_thr))
+        if dets_th.is_cuda:
+            inds = nms_cuda.nms(dets_th, iou_thr)
+        else:
+            inds = nms_cpu.nms(dets_th, iou_thr)
 
-    if is_tensor:
-        inds = dets.new_tensor(inds, dtype=torch.long)
-    else:
-        inds = np.array(inds, dtype=np.int64)
+    if is_numpy:
+        inds = inds.cpu().numpy()
     return dets[inds, :], inds
 
 
@@ -49,7 +64,7 @@ def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3):
     method_codes = {'linear': 1, 'gaussian': 2}
     if method not in method_codes:
         raise ValueError('Invalid method for SoftNMS: {}'.format(method))
-    new_dets, inds = cpu_soft_nms(
+    new_dets, inds = soft_nms_cpu(
         dets_np,
         iou_thr,
         method=method_codes[method],
diff --git a/mmdet/ops/nms/setup.py b/mmdet/ops/nms/setup.py
index a8fe373..28f3b4e 100644
--- a/mmdet/ops/nms/setup.py
+++ b/mmdet/ops/nms/setup.py
@@ -1,11 +1,11 @@
 import os.path as osp
-from distutils.core import setup, Extension
+from setuptools import setup, Extension
 
 import numpy as np
 from Cython.Build import cythonize
 from Cython.Distutils import build_ext
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
-# extensions
 ext_args = dict(
     include_dirs=[np.get_include()],
     language='c++',
@@ -16,9 +16,7 @@ ext_args = dict(
 )
 
 extensions = [
-    Extension('cpu_nms', ['cpu_nms.pyx'], **ext_args),
-    Extension('cpu_soft_nms', ['cpu_soft_nms.pyx'], **ext_args),
-    Extension('gpu_nms', ['gpu_nms.pyx', 'nms_kernel.cu'], **ext_args),
+    Extension('soft_nms_cpu', ['src/soft_nms_cpu.pyx'], **ext_args),
 ]
 
 
@@ -59,7 +57,6 @@ def customize_compiler_for_nvcc(self):
     self._compile = _compile
 
 
-# run the customize_compiler
 class custom_build_ext(build_ext):
 
     def build_extensions(self):
@@ -68,7 +65,20 @@ class custom_build_ext(build_ext):
 
 
 setup(
-    name='nms',
+    name='soft_nms',
     cmdclass={'build_ext': custom_build_ext},
     ext_modules=cythonize(extensions),
 )
+
+setup(
+    name='nms_cuda',
+    ext_modules=[
+        CUDAExtension('nms_cuda', [
+            'src/nms_cuda.cpp',
+            'src/nms_kernel.cu',
+        ]),
+        CUDAExtension('nms_cpu', [
+            'src/nms_cpu.cpp',
+        ]),
+    ],
+    cmdclass={'build_ext': BuildExtension})
diff --git a/mmdet/ops/nms/src/nms_cpu.cpp b/mmdet/ops/nms/src/nms_cpu.cpp
new file mode 100644
index 0000000..65546ef
--- /dev/null
+++ b/mmdet/ops/nms/src/nms_cpu.cpp
@@ -0,0 +1,71 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#include <torch/extension.h>
+
+template <typename scalar_t>
+at::Tensor nms_cpu_kernel(const at::Tensor& dets, const float threshold) {
+  AT_ASSERTM(!dets.type().is_cuda(), "dets must be a CPU tensor");
+
+  if (dets.numel() == 0) {
+    return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
+  }
+
+  auto x1_t = dets.select(1, 0).contiguous();
+  auto y1_t = dets.select(1, 1).contiguous();
+  auto x2_t = dets.select(1, 2).contiguous();
+  auto y2_t = dets.select(1, 3).contiguous();
+  auto scores = dets.select(1, 4).contiguous();
+
+  at::Tensor areas_t = (x2_t - x1_t + 1) * (y2_t - y1_t + 1);
+
+  auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
+
+  auto ndets = dets.size(0);
+  at::Tensor suppressed_t =
+      at::zeros({ndets}, dets.options().dtype(at::kByte).device(at::kCPU));
+
+  auto suppressed = suppressed_t.data<uint8_t>();
+  auto order = order_t.data<int64_t>();
+  auto x1 = x1_t.data<scalar_t>();
+  auto y1 = y1_t.data<scalar_t>();
+  auto x2 = x2_t.data<scalar_t>();
+  auto y2 = y2_t.data<scalar_t>();
+  auto areas = areas_t.data<scalar_t>();
+
+  for (int64_t _i = 0; _i < ndets; _i++) {
+    auto i = order[_i];
+    if (suppressed[i] == 1) continue;
+    auto ix1 = x1[i];
+    auto iy1 = y1[i];
+    auto ix2 = x2[i];
+    auto iy2 = y2[i];
+    auto iarea = areas[i];
+
+    for (int64_t _j = _i + 1; _j < ndets; _j++) {
+      auto j = order[_j];
+      if (suppressed[j] == 1) continue;
+      auto xx1 = std::max(ix1, x1[j]);
+      auto yy1 = std::max(iy1, y1[j]);
+      auto xx2 = std::min(ix2, x2[j]);
+      auto yy2 = std::min(iy2, y2[j]);
+
+      auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1 + 1);
+      auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1 + 1);
+      auto inter = w * h;
+      auto ovr = inter / (iarea + areas[j] - inter);
+      if (ovr >= threshold) suppressed[j] = 1;
+    }
+  }
+  return at::nonzero(suppressed_t == 0).squeeze(1);
+}
+
+at::Tensor nms(const at::Tensor& dets, const float threshold) {
+  at::Tensor result;
+  AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms", [&] {
+    result = nms_cpu_kernel<scalar_t>(dets, threshold);
+  });
+  return result;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("nms", &nms, "non-maximum suppression");
+}
\ No newline at end of file
diff --git a/mmdet/ops/nms/src/nms_cuda.cpp b/mmdet/ops/nms/src/nms_cuda.cpp
new file mode 100644
index 0000000..0ea6f9b
--- /dev/null
+++ b/mmdet/ops/nms/src/nms_cuda.cpp
@@ -0,0 +1,17 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#include <torch/extension.h>
+
+#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
+
+at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh);
+
+at::Tensor nms(const at::Tensor& dets, const float threshold) {
+  CHECK_CUDA(dets);
+  if (dets.numel() == 0)
+    return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
+  return nms_cuda(dets, threshold);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("nms", &nms, "non-maximum suppression");
+}
\ No newline at end of file
diff --git a/mmdet/ops/nms/src/nms_kernel.cu b/mmdet/ops/nms/src/nms_kernel.cu
new file mode 100644
index 0000000..9254f2a
--- /dev/null
+++ b/mmdet/ops/nms/src/nms_kernel.cu
@@ -0,0 +1,131 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCDeviceUtils.cuh>
+
+#include <vector>
+#include <iostream>
+
+int const threadsPerBlock = sizeof(unsigned long long) * 8;
+
+__device__ inline float devIoU(float const * const a, float const * const b) {
+  float left = max(a[0], b[0]), right = min(a[2], b[2]);
+  float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
+  float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
+  float interS = width * height;
+  float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
+  float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
+  return interS / (Sa + Sb - interS);
+}
+
+__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
+                           const float *dev_boxes, unsigned long long *dev_mask) {
+  const int row_start = blockIdx.y;
+  const int col_start = blockIdx.x;
+
+  // if (row_start > col_start) return;
+
+  const int row_size =
+        min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
+  const int col_size =
+        min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
+
+  __shared__ float block_boxes[threadsPerBlock * 5];
+  if (threadIdx.x < col_size) {
+    block_boxes[threadIdx.x * 5 + 0] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
+    block_boxes[threadIdx.x * 5 + 1] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
+    block_boxes[threadIdx.x * 5 + 2] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
+    block_boxes[threadIdx.x * 5 + 3] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
+    block_boxes[threadIdx.x * 5 + 4] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
+  }
+  __syncthreads();
+
+  if (threadIdx.x < row_size) {
+    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
+    const float *cur_box = dev_boxes + cur_box_idx * 5;
+    int i = 0;
+    unsigned long long t = 0;
+    int start = 0;
+    if (row_start == col_start) {
+      start = threadIdx.x + 1;
+    }
+    for (i = start; i < col_size; i++) {
+      if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
+        t |= 1ULL << i;
+      }
+    }
+    const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
+    dev_mask[cur_box_idx * col_blocks + col_start] = t;
+  }
+}
+
+// boxes is a N x 5 tensor
+at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
+  using scalar_t = float;
+  AT_ASSERTM(boxes.type().is_cuda(), "boxes must be a CUDA tensor");
+  auto scores = boxes.select(1, 4);
+  auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
+  auto boxes_sorted = boxes.index_select(0, order_t);
+
+  int boxes_num = boxes.size(0);
+
+  const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
+
+  scalar_t* boxes_dev = boxes_sorted.data<scalar_t>();
+
+  THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
+
+  unsigned long long* mask_dev = NULL;
+  //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
+  //                      boxes_num * col_blocks * sizeof(unsigned long long)));
+
+  mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
+
+  dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
+              THCCeilDiv(boxes_num, threadsPerBlock));
+  dim3 threads(threadsPerBlock);
+  nms_kernel<<<blocks, threads>>>(boxes_num,
+                                  nms_overlap_thresh,
+                                  boxes_dev,
+                                  mask_dev);
+
+  std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
+  THCudaCheck(cudaMemcpy(&mask_host[0],
+                        mask_dev,
+                        sizeof(unsigned long long) * boxes_num * col_blocks,
+                        cudaMemcpyDeviceToHost));
+
+  std::vector<unsigned long long> remv(col_blocks);
+  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
+
+  at::Tensor keep = at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU));
+  int64_t* keep_out = keep.data<int64_t>();
+
+  int num_to_keep = 0;
+  for (int i = 0; i < boxes_num; i++) {
+    int nblock = i / threadsPerBlock;
+    int inblock = i % threadsPerBlock;
+
+    if (!(remv[nblock] & (1ULL << inblock))) {
+      keep_out[num_to_keep++] = i;
+      unsigned long long *p = &mask_host[0] + i * col_blocks;
+      for (int j = nblock; j < col_blocks; j++) {
+        remv[j] |= p[j];
+      }
+    }
+  }
+
+  THCudaFree(state, mask_dev);
+  // TODO improve this part
+  return std::get<0>(order_t.index({
+                       keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
+                         order_t.device(), keep.scalar_type())
+                     }).sort(0, false));
+}
\ No newline at end of file
diff --git a/mmdet/ops/nms/cpu_soft_nms.pyx b/mmdet/ops/nms/src/soft_nms_cpu.pyx
similarity index 99%
rename from mmdet/ops/nms/cpu_soft_nms.pyx
rename to mmdet/ops/nms/src/soft_nms_cpu.pyx
index 48cb023..c35f8f1 100644
--- a/mmdet/ops/nms/cpu_soft_nms.pyx
+++ b/mmdet/ops/nms/src/soft_nms_cpu.pyx
@@ -19,7 +19,7 @@ cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
     return a if a <= b else b
 
 
-def cpu_soft_nms(
+def soft_nms_cpu(
     np.ndarray[float, ndim=2] boxes_in,
     float iou_thr,
     unsigned int method=1,
-- 
GitLab