diff --git a/README.md b/README.md
index 3759c6524af47344be43cca593427dc7a63f4e82..5d886bdf3cebbcf158091a9b73dca586cb9957bf 100644
--- a/README.md
+++ b/README.md
@@ -38,6 +38,7 @@ This project is released under the [Apache 2.0 license](LICENSE).
 
 v0.5.3 (26/11/2018)
 - Add Cascade R-CNN and Cascade Mask R-CNN.
+- Add support for Soft-NMS in config files.
 
 v0.5.2 (21/10/2018)
 - Add support for custom datasets.
diff --git a/configs/cascade_mask_rcnn_r50_fpn_1x.py b/configs/cascade_mask_rcnn_r50_fpn_1x.py
index 9f3f8b6ff2222924be543049166df834fbd6763d..538b468127d3a6319e5bae1151fec489ca266a61 100644
--- a/configs/cascade_mask_rcnn_r50_fpn_1x.py
+++ b/configs/cascade_mask_rcnn_r50_fpn_1x.py
@@ -152,7 +152,10 @@ test_cfg = dict(
         nms_thr=0.7,
         min_bbox_size=0),
     rcnn=dict(
-        score_thr=0.05, max_per_img=100, nms_thr=0.5, mask_thr_binary=0.5),
+        score_thr=0.05,
+        nms=dict(type='nms', iou_thr=0.5),
+        max_per_img=100,
+        mask_thr_binary=0.5),
     keep_all_stages=False)
 # dataset settings
 dataset_type = 'CocoDataset'
diff --git a/configs/cascade_rcnn_r50_fpn_1x.py b/configs/cascade_rcnn_r50_fpn_1x.py
index 5b4a70c9d86a835b6a7d270da0b80a49a904e6c3..69a2e520db643f2829af42f0cadd34c477494a4f 100644
--- a/configs/cascade_rcnn_r50_fpn_1x.py
+++ b/configs/cascade_rcnn_r50_fpn_1x.py
@@ -137,7 +137,8 @@ test_cfg = dict(
         max_num=2000,
         nms_thr=0.7,
         min_bbox_size=0),
-    rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5),
+    rcnn=dict(
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100),
     keep_all_stages=False)
 # dataset settings
 dataset_type = 'CocoDataset'
diff --git a/configs/fast_mask_rcnn_r50_fpn_1x.py b/configs/fast_mask_rcnn_r50_fpn_1x.py
index d930289f53f8745be812ef5fb1840fb564ba3d4a..8863ba68404119f2f7dc2f627c70c21c6a77e313 100644
--- a/configs/fast_mask_rcnn_r50_fpn_1x.py
+++ b/configs/fast_mask_rcnn_r50_fpn_1x.py
@@ -60,7 +60,10 @@ train_cfg = dict(
         debug=False))
 test_cfg = dict(
     rcnn=dict(
-        score_thr=0.05, max_per_img=100, nms_thr=0.5, mask_thr_binary=0.5))
+        score_thr=0.05,
+        nms=dict(type='nms', iou_thr=0.5),
+        max_per_img=100,
+        mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/fast_rcnn_r50_fpn_1x.py b/configs/fast_rcnn_r50_fpn_1x.py
index da223fc52e79c7a1da3b1bbf32d7d84c9eabd6f2..57394bc2895efaa203c2c8586d37ee3e0db300fa 100644
--- a/configs/fast_rcnn_r50_fpn_1x.py
+++ b/configs/fast_rcnn_r50_fpn_1x.py
@@ -46,7 +46,9 @@ train_cfg = dict(
             neg_balance_thr=0),
         pos_weight=-1,
         debug=False))
-test_cfg = dict(rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5))
+test_cfg = dict(
+    rcnn=dict(
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/faster_rcnn_r50_fpn_1x.py b/configs/faster_rcnn_r50_fpn_1x.py
index 31301db1feabcf42fc0acbd45f13ecb21d00420a..97899afc673c2346e6a9bd0d2b5b5196cde66eff 100644
--- a/configs/faster_rcnn_r50_fpn_1x.py
+++ b/configs/faster_rcnn_r50_fpn_1x.py
@@ -81,7 +81,11 @@ test_cfg = dict(
         max_num=2000,
         nms_thr=0.7,
         min_bbox_size=0),
-    rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5))
+    rcnn=dict(
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
+    # soft-nms is also supported for rcnn testing
+    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
+)
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/mask_rcnn_r50_fpn_1x.py b/configs/mask_rcnn_r50_fpn_1x.py
index 144521cc6b6e58fd8d469b40fb061ee6f6223efe..c2ef8faf563aa3b55ef86a7cf2a4c7ca5eb67519 100644
--- a/configs/mask_rcnn_r50_fpn_1x.py
+++ b/configs/mask_rcnn_r50_fpn_1x.py
@@ -94,7 +94,10 @@ test_cfg = dict(
         nms_thr=0.7,
         min_bbox_size=0),
     rcnn=dict(
-        score_thr=0.05, max_per_img=100, nms_thr=0.5, mask_thr_binary=0.5))
+        score_thr=0.05,
+        nms=dict(type='nms', iou_thr=0.5),
+        max_per_img=100,
+        mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/mmdet/core/post_processing/bbox_nms.py b/mmdet/core/post_processing/bbox_nms.py
index f619d2682a035344c6fda6974cd03c5cbfeb0f26..1f7c6f17e36af180c7dd78a48bf431a4ad85e226 100644
--- a/mmdet/core/post_processing/bbox_nms.py
+++ b/mmdet/core/post_processing/bbox_nms.py
@@ -1,9 +1,9 @@
 import torch
 
-from mmdet.ops import nms
+from mmdet.ops.nms import nms_wrapper
 
 
-def multiclass_nms(multi_bboxes, multi_scores, score_thr, nms_thr, max_num=-1):
+def multiclass_nms(multi_bboxes, multi_scores, score_thr, nms_cfg, max_num=-1):
     """NMS for multi-class bboxes.
 
     Args:
@@ -21,6 +21,9 @@ def multiclass_nms(multi_bboxes, multi_scores, score_thr, nms_thr, max_num=-1):
     """
     num_classes = multi_scores.shape[1]
     bboxes, labels = [], []
+    nms_cfg_ = nms_cfg.copy()
+    nms_type = nms_cfg_.pop('type', 'nms')
+    nms_op = getattr(nms_wrapper, nms_type)
     for i in range(1, num_classes):
         cls_inds = multi_scores[:, i] > score_thr
         if not cls_inds.any():
@@ -32,11 +35,9 @@ def multiclass_nms(multi_bboxes, multi_scores, score_thr, nms_thr, max_num=-1):
             _bboxes = multi_bboxes[cls_inds, i * 4:(i + 1) * 4]
         _scores = multi_scores[cls_inds, i]
         cls_dets = torch.cat([_bboxes, _scores[:, None]], dim=1)
-        # perform nms
-        nms_keep = nms(cls_dets, nms_thr)
-        cls_dets = cls_dets[nms_keep, :]
+        cls_dets, _ = nms_op(cls_dets, **nms_cfg_)
         cls_labels = multi_bboxes.new_full(
-            (len(nms_keep), ), i - 1, dtype=torch.long)
+            (cls_dets.shape[0], ), i - 1, dtype=torch.long)
         bboxes.append(cls_dets)
         labels.append(cls_labels)
     if bboxes:
diff --git a/mmdet/core/post_processing/merge_augs.py b/mmdet/core/post_processing/merge_augs.py
index 00f65b049ccf2b00a0fee73cc64ac257415425ea..f97954b8a77ebe97cec74ee4420be953146c63e4 100644
--- a/mmdet/core/post_processing/merge_augs.py
+++ b/mmdet/core/post_processing/merge_augs.py
@@ -29,9 +29,7 @@ def merge_aug_proposals(aug_proposals, img_metas, rpn_test_cfg):
                                               scale_factor, flip)
         recovered_proposals.append(_proposals)
     aug_proposals = torch.cat(recovered_proposals, dim=0)
-    nms_keep = nms(aug_proposals, rpn_test_cfg.nms_thr,
-                   aug_proposals.get_device())
-    merged_proposals = aug_proposals[nms_keep, :]
+    merged_proposals, _ = nms(aug_proposals, rpn_test_cfg.nms_thr)
     scores = merged_proposals[:, 4]
     _, order = scores.sort(0, descending=True)
     num = min(rpn_test_cfg.max_num, merged_proposals.shape[0])
diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py
index 0a576999918f191b26963aeb712854e97cec744a..828e8b0d28e8c3563c2b6bca5d6c8fa3eaacdbca 100644
--- a/mmdet/models/bbox_heads/bbox_head.py
+++ b/mmdet/models/bbox_heads/bbox_head.py
@@ -100,7 +100,7 @@ class BBoxHead(nn.Module):
                        img_shape,
                        scale_factor,
                        rescale=False,
-                       nms_cfg=None):
+                       cfg=None):
         if isinstance(cls_score, list):
             cls_score = sum(cls_score) / float(len(cls_score))
         scores = F.softmax(cls_score, dim=1) if cls_score is not None else None
@@ -115,12 +115,11 @@ class BBoxHead(nn.Module):
         if rescale:
             bboxes /= scale_factor
 
-        if nms_cfg is None:
+        if cfg is None:
             return bboxes, scores
         else:
             det_bboxes, det_labels = multiclass_nms(
-                bboxes, scores, nms_cfg.score_thr, nms_cfg.nms_thr,
-                nms_cfg.max_per_img)
+                bboxes, scores, cfg.score_thr, cfg.nms, cfg.max_per_img)
 
             return det_bboxes, det_labels
 
diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py
index 91d3eaf4fe27a136fb1f874d66fb059388b7e364..c2ab5f9b5a1d2d5af314e2c5d7407e16192f19a2 100644
--- a/mmdet/models/detectors/cascade_rcnn.py
+++ b/mmdet/models/detectors/cascade_rcnn.py
@@ -218,7 +218,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
                     img_shape,
                     scale_factor,
                     rescale=rescale,
-                    nms_cfg=rcnn_test_cfg)
+                    cfg=rcnn_test_cfg)
                 bbox_result = bbox2result(det_bboxes, det_labels,
                                           bbox_head.num_classes)
                 ms_bbox_result['stage{}'.format(i)] = bbox_result
@@ -256,7 +256,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
             img_shape,
             scale_factor,
             rescale=rescale,
-            nms_cfg=rcnn_test_cfg)
+            cfg=rcnn_test_cfg)
         bbox_result = bbox2result(det_bboxes, det_labels,
                                   self.bbox_head[-1].num_classes)
         ms_bbox_result['ensemble'] = bbox_result
diff --git a/mmdet/models/detectors/test_mixins.py b/mmdet/models/detectors/test_mixins.py
index 38136f47545c49d88253fee321c91f9408058ca9..2baf100601d970397cc8800a81766dd20d7ff2e0 100644
--- a/mmdet/models/detectors/test_mixins.py
+++ b/mmdet/models/detectors/test_mixins.py
@@ -47,7 +47,7 @@ class BBoxTestMixin(object):
             img_shape,
             scale_factor,
             rescale=rescale,
-            nms_cfg=rcnn_test_cfg)
+            cfg=rcnn_test_cfg)
         return det_bboxes, det_labels
 
     def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
@@ -73,15 +73,15 @@ class BBoxTestMixin(object):
                 img_shape,
                 scale_factor,
                 rescale=False,
-                nms_cfg=None)
+                cfg=None)
             aug_bboxes.append(bboxes)
             aug_scores.append(scores)
         # after merging, bboxes will be rescaled to the original image size
         merged_bboxes, merged_scores = merge_aug_bboxes(
-            aug_bboxes, aug_scores, img_metas, self.test_cfg.rcnn)
+            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
         det_bboxes, det_labels = multiclass_nms(
-            merged_bboxes, merged_scores, self.test_cfg.rcnn.score_thr,
-            self.test_cfg.rcnn.nms_thr, self.test_cfg.rcnn.max_per_img)
+            merged_bboxes, merged_scores, rcnn_test_cfg.score_thr,
+            rcnn_test_cfg.nms, rcnn_test_cfg.max_per_img)
         return det_bboxes, det_labels
 
 
diff --git a/mmdet/models/rpn_heads/rpn_head.py b/mmdet/models/rpn_heads/rpn_head.py
index ad06e40c8fcee1b2991029d8121b9ffad54cbf32..3da0619083d130ebc1149c02888f858e76b027a3 100644
--- a/mmdet/models/rpn_heads/rpn_head.py
+++ b/mmdet/models/rpn_heads/rpn_head.py
@@ -234,13 +234,13 @@ class RPNHead(nn.Module):
             proposals = proposals[valid_inds, :]
             scores = scores[valid_inds]
             proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
-            nms_keep = nms(proposals, cfg.nms_thr)[:cfg.nms_post]
-            proposals = proposals[nms_keep, :]
+            proposals, _ = nms(proposals, cfg.nms_thr)
+            proposals = proposals[:cfg.nms_post, :]
             mlvl_proposals.append(proposals)
         proposals = torch.cat(mlvl_proposals, 0)
         if cfg.nms_across_levels:
-            nms_keep = nms(proposals, cfg.nms_thr)[:cfg.max_num]
-            proposals = proposals[nms_keep, :]
+            proposals, _ = nms(proposals, cfg.nms_thr)
+            proposals = proposals[:cfg.max_num, :]
         else:
             scores = proposals[:, 4]
             _, order = scores.sort(0, descending=True)
diff --git a/mmdet/ops/nms/cpu_soft_nms.pyx b/mmdet/ops/nms/cpu_soft_nms.pyx
index 05ec5a5446221d3593a10edfd4d714bfa6192309..189dcee366acad23c05ecb286ce1e0915122c564 100644
--- a/mmdet/ops/nms/cpu_soft_nms.pyx
+++ b/mmdet/ops/nms/cpu_soft_nms.pyx
@@ -3,6 +3,7 @@
 # Copyright (c) University of Maryland, College Park
 # Licensed under The MIT License [see LICENSE for details]
 # Written by Navaneeth Bodla and Bharat Singh
+# Modified by Kai Chen
 # ----------------------------------------------------------
 
 import numpy as np
@@ -15,12 +16,13 @@ cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
 cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
     return a if a <= b else b
 
+
 def cpu_soft_nms(
     np.ndarray[float, ndim=2] boxes_in,
+    float iou_thr,
+    unsigned int method=1,
     float sigma=0.5,
-    float Nt=0.3,
-    float threshold=0.001,
-    unsigned int method=0
+    float min_score=0.001,
 ):
     boxes = boxes_in.copy()
     cdef unsigned int N = boxes.shape[0]
@@ -36,11 +38,11 @@ def cpu_soft_nms(
         maxscore = boxes[i, 4]
         maxpos = i
 
-        tx1 = boxes[i,0]
-        ty1 = boxes[i,1]
-        tx2 = boxes[i,2]
-        ty2 = boxes[i,3]
-        ts = boxes[i,4]
+        tx1 = boxes[i, 0]
+        ty1 = boxes[i, 1]
+        tx2 = boxes[i, 2]
+        ty2 = boxes[i, 3]
+        ts = boxes[i, 4]
         ti = inds[i]
 
         pos = i + 1
@@ -52,26 +54,26 @@ def cpu_soft_nms(
             pos = pos + 1
 
         # add max box as a detection
-        boxes[i,0] = boxes[maxpos,0]
-        boxes[i,1] = boxes[maxpos,1]
-        boxes[i,2] = boxes[maxpos,2]
-        boxes[i,3] = boxes[maxpos,3]
-        boxes[i,4] = boxes[maxpos,4]
+        boxes[i, 0] = boxes[maxpos, 0]
+        boxes[i, 1] = boxes[maxpos, 1]
+        boxes[i, 2] = boxes[maxpos, 2]
+        boxes[i, 3] = boxes[maxpos, 3]
+        boxes[i, 4] = boxes[maxpos, 4]
         inds[i] = inds[maxpos]
 
         # swap ith box with position of max box
-        boxes[maxpos,0] = tx1
-        boxes[maxpos,1] = ty1
-        boxes[maxpos,2] = tx2
-        boxes[maxpos,3] = ty2
-        boxes[maxpos,4] = ts
+        boxes[maxpos, 0] = tx1
+        boxes[maxpos, 1] = ty1
+        boxes[maxpos, 2] = tx2
+        boxes[maxpos, 3] = ty2
+        boxes[maxpos, 4] = ts
         inds[maxpos] = ti
 
-        tx1 = boxes[i,0]
-        ty1 = boxes[i,1]
-        tx2 = boxes[i,2]
-        ty2 = boxes[i,3]
-        ts = boxes[i,4]
+        tx1 = boxes[i, 0]
+        ty1 = boxes[i, 1]
+        tx2 = boxes[i, 2]
+        ty2 = boxes[i, 3]
+        ts = boxes[i, 4]
 
         pos = i + 1
         # NMS iterations, note that N changes if detection boxes fall below
@@ -89,35 +91,35 @@ def cpu_soft_nms(
                 ih = (min(ty2, y2) - max(ty1, y1) + 1)
                 if ih > 0:
                     ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih)
-                    ov = iw * ih / ua #iou between max box and detection box
+                    ov = iw * ih / ua  # iou between max box and detection box
 
-                    if method == 1: # linear
-                        if ov > Nt:
+                    if method == 1:  # linear
+                        if ov > iou_thr:
                             weight = 1 - ov
                         else:
                             weight = 1
-                    elif method == 2: # gaussian
-                        weight = np.exp(-(ov * ov)/sigma)
-                    else: # original NMS
-                        if ov > Nt:
+                    elif method == 2:  # gaussian
+                        weight = np.exp(-(ov * ov) / sigma)
+                    else:  # original NMS
+                        if ov > iou_thr:
                             weight = 0
                         else:
                             weight = 1
 
-                    boxes[pos, 4] = weight*boxes[pos, 4]
+                    boxes[pos, 4] = weight * boxes[pos, 4]
 
                     # if box score falls below threshold, discard the box by
                     # swapping with last box update N
-                    if boxes[pos, 4] < threshold:
-                        boxes[pos,0] = boxes[N-1, 0]
-                        boxes[pos,1] = boxes[N-1, 1]
-                        boxes[pos,2] = boxes[N-1, 2]
-                        boxes[pos,3] = boxes[N-1, 3]
-                        boxes[pos,4] = boxes[N-1, 4]
-                        inds[pos] = inds[N-1]
+                    if boxes[pos, 4] < min_score:
+                        boxes[pos, 0] = boxes[N-1, 0]
+                        boxes[pos, 1] = boxes[N-1, 1]
+                        boxes[pos, 2] = boxes[N-1, 2]
+                        boxes[pos, 3] = boxes[N-1, 3]
+                        boxes[pos, 4] = boxes[N-1, 4]
+                        inds[pos] = inds[N - 1]
                         N = N - 1
                         pos = pos - 1
 
             pos = pos + 1
 
-    return boxes[:N], inds[:N]
\ No newline at end of file
+    return boxes[:N], inds[:N]
diff --git a/mmdet/ops/nms/nms_wrapper.py b/mmdet/ops/nms/nms_wrapper.py
index 3978773b842a8236dedc8b942ad1d3100061813c..83b2858cdbe62e9bf12fb0728703afaeb1bb846f 100644
--- a/mmdet/ops/nms/nms_wrapper.py
+++ b/mmdet/ops/nms/nms_wrapper.py
@@ -6,43 +6,58 @@ from .cpu_nms import cpu_nms
 from .cpu_soft_nms import cpu_soft_nms
 
 
-def nms(dets, thresh, device_id=None):
+def nms(dets, iou_thr, device_id=None):
     """Dispatch to either CPU or GPU NMS implementations."""
-
-    tensor_device = None
     if isinstance(dets, torch.Tensor):
-        tensor_device = dets.device
+        is_tensor = True
         if dets.is_cuda:
             device_id = dets.get_device()
-        dets = dets.detach().cpu().numpy()
-    assert isinstance(dets, np.ndarray)
+        dets_np = dets.detach().cpu().numpy()
+    elif isinstance(dets, np.ndarray):
+        is_tensor = False
+        dets_np = dets
+    else:
+        raise TypeError(
+            'dets must be either a Tensor or numpy array, but got {}'.format(
+                type(dets)))
 
-    if dets.shape[0] == 0:
+    if dets_np.shape[0] == 0:
         inds = []
     else:
-        inds = (gpu_nms(dets, thresh, device_id=device_id)
-                if device_id is not None else cpu_nms(dets, thresh))
+        inds = (gpu_nms(dets_np, iou_thr, device_id=device_id)
+                if device_id is not None else cpu_nms(dets_np, iou_thr))
 
-    if tensor_device:
-        return torch.Tensor(inds).long().to(tensor_device)
+    if is_tensor:
+        inds = dets.new_tensor(inds, dtype=torch.long)
     else:
-        return np.array(inds, dtype=np.int)
+        inds = np.array(inds, dtype=np.int64)
+    return dets[inds, :], inds
 
 
-def soft_nms(dets, Nt=0.3, method=1, sigma=0.5, min_score=0):
+def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3):
     if isinstance(dets, torch.Tensor):
-        _dets = dets.detach().cpu().numpy()
+        is_tensor = True
+        dets_np = dets.detach().cpu().numpy()
+    elif isinstance(dets, np.ndarray):
+        is_tensor = False
+        dets_np = dets
     else:
-        _dets = dets.copy()
-    assert isinstance(_dets, np.ndarray)
+        raise TypeError(
+            'dets must be either a Tensor or numpy array, but got {}'.format(
+                type(dets)))
 
+    method_codes = {'linear': 1, 'gaussian': 2}
+    if method not in method_codes:
+        raise ValueError('Invalid method for SoftNMS: {}'.format(method))
     new_dets, inds = cpu_soft_nms(
-        _dets, Nt=Nt, method=method, sigma=sigma, threshold=min_score)
-
-    if isinstance(dets, torch.Tensor):
-        return dets.new_tensor(
-            inds, dtype=torch.long), dets.new_tensor(new_dets)
+        dets_np,
+        iou_thr,
+        method=method_codes[method],
+        sigma=sigma,
+        min_score=min_score)
+
+    if is_tensor:
+        return dets.new_tensor(new_dets), dets.new_tensor(
+            inds, dtype=torch.long)
     else:
-        return np.array(
-            inds, dtype=np.int), np.array(
-                new_dets, dtype=np.float32)
+        return new_dets.astype(np.float32), inds.astype(np.int64)