From 9ace2eee23400a334ed8a7337e4f1fdfd024af63 Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Tue, 27 Nov 2018 15:40:12 +0800 Subject: [PATCH] support different nms methods --- configs/retinanet_r50_fpn_1x.py | 2 +- mmdet/models/single_stage_heads/retina_head.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/retinanet_r50_fpn_1x.py b/configs/retinanet_r50_fpn_1x.py index 079d6e5..77f67de 100644 --- a/configs/retinanet_r50_fpn_1x.py +++ b/configs/retinanet_r50_fpn_1x.py @@ -40,9 +40,9 @@ train_cfg = dict( debug=False) test_cfg = dict( nms_pre=1000, - nms_thr=0.5, min_bbox_size=0, score_thr=0.05, + nms=dict(type='nms', iou_thr=0.5), max_per_img=100) # dataset settings dataset_type = 'CocoDataset' diff --git a/mmdet/models/single_stage_heads/retina_head.py b/mmdet/models/single_stage_heads/retina_head.py index 8fd047f..eb2596d 100644 --- a/mmdet/models/single_stage_heads/retina_head.py +++ b/mmdet/models/single_stage_heads/retina_head.py @@ -282,6 +282,6 @@ class RetinaHead(nn.Module): padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) mlvl_scores = torch.cat([padding, mlvl_scores], dim=1) det_bboxes, det_labels = multiclass_nms(mlvl_proposals, mlvl_scores, - cfg.score_thr, cfg.nms_thr, + cfg.score_thr, cfg.nms, cfg.max_per_img) return det_bboxes, det_labels -- GitLab