diff --git a/configs/retinanet_r50_fpn_1x.py b/configs/retinanet_r50_fpn_1x.py index 079d6e53ba92197163374851f421c54adff00162..77f67deccb03e05f496d1065d53710999041d8f9 100644 --- a/configs/retinanet_r50_fpn_1x.py +++ b/configs/retinanet_r50_fpn_1x.py @@ -40,9 +40,9 @@ train_cfg = dict( debug=False) test_cfg = dict( nms_pre=1000, - nms_thr=0.5, min_bbox_size=0, score_thr=0.05, + nms=dict(type='nms', iou_thr=0.5), max_per_img=100) # dataset settings dataset_type = 'CocoDataset' diff --git a/mmdet/models/single_stage_heads/retina_head.py b/mmdet/models/single_stage_heads/retina_head.py index 8fd047f99691c06a60823236a338fb97a46c78b7..eb2596d31e152c3fe7f3b6152812046b41a73462 100644 --- a/mmdet/models/single_stage_heads/retina_head.py +++ b/mmdet/models/single_stage_heads/retina_head.py @@ -282,6 +282,6 @@ class RetinaHead(nn.Module): padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) mlvl_scores = torch.cat([padding, mlvl_scores], dim=1) det_bboxes, det_labels = multiclass_nms(mlvl_proposals, mlvl_scores, - cfg.score_thr, cfg.nms_thr, + cfg.score_thr, cfg.nms, cfg.max_per_img) return det_bboxes, det_labels