From 9ace2eee23400a334ed8a7337e4f1fdfd024af63 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Tue, 27 Nov 2018 15:40:12 +0800
Subject: [PATCH] support different nms methods

---
 configs/retinanet_r50_fpn_1x.py                | 2 +-
 mmdet/models/single_stage_heads/retina_head.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/configs/retinanet_r50_fpn_1x.py b/configs/retinanet_r50_fpn_1x.py
index 079d6e5..77f67de 100644
--- a/configs/retinanet_r50_fpn_1x.py
+++ b/configs/retinanet_r50_fpn_1x.py
@@ -40,9 +40,9 @@ train_cfg = dict(
     debug=False)
 test_cfg = dict(
     nms_pre=1000,
-    nms_thr=0.5,
     min_bbox_size=0,
     score_thr=0.05,
+    nms=dict(type='nms', iou_thr=0.5),
     max_per_img=100)
 # dataset settings
 dataset_type = 'CocoDataset'
diff --git a/mmdet/models/single_stage_heads/retina_head.py b/mmdet/models/single_stage_heads/retina_head.py
index 8fd047f..eb2596d 100644
--- a/mmdet/models/single_stage_heads/retina_head.py
+++ b/mmdet/models/single_stage_heads/retina_head.py
@@ -282,6 +282,6 @@ class RetinaHead(nn.Module):
         padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
         mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
         det_bboxes, det_labels = multiclass_nms(mlvl_proposals, mlvl_scores,
-                                                cfg.score_thr, cfg.nms_thr,
+                                                cfg.score_thr, cfg.nms,
                                                 cfg.max_per_img)
         return det_bboxes, det_labels
-- 
GitLab