From 1ed2d3542030896fd73064644e3459d7be70a0cc Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Fri, 5 Oct 2018 20:23:50 +0800
Subject: [PATCH] vis results for Mask RCNN and update test thr from 0.001 to
 0.05

---
 mmdet/models/detectors/mask_rcnn.py  | 27 +++++++++++++++++----------
 tools/configs/r50_fpn_frcnn_1x.py    |  2 +-
 tools/configs/r50_fpn_maskrcnn_1x.py |  2 +-
 3 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/mmdet/models/detectors/mask_rcnn.py b/mmdet/models/detectors/mask_rcnn.py
index 684598a..25a363e 100644
--- a/mmdet/models/detectors/mask_rcnn.py
+++ b/mmdet/models/detectors/mask_rcnn.py
@@ -15,13 +15,20 @@ class MaskRCNN(TwoStageDetector):
                  test_cfg,
                  pretrained=None):
         super(MaskRCNN, self).__init__(
-                    backbone=backbone,
-                    neck=neck,
-                    rpn_head=rpn_head,
-                    bbox_roi_extractor=bbox_roi_extractor,
-                    bbox_head=bbox_head,
-                    mask_roi_extractor=mask_roi_extractor,
-                    mask_head=mask_head,
-                    train_cfg=train_cfg,
-                    test_cfg=test_cfg,
-                    pretrained=pretrained)
+            backbone=backbone,
+            neck=neck,
+            rpn_head=rpn_head,
+            bbox_roi_extractor=bbox_roi_extractor,
+            bbox_head=bbox_head,
+            mask_roi_extractor=mask_roi_extractor,
+            mask_head=mask_head,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            pretrained=pretrained)
+
+    def show_result(self, data, result, img_norm_cfg, **kwargs):
+        # TODO: show segmentation masks
+        assert isinstance(result, tuple)
+        assert len(result) == 2  # (bbox_results, segm_results)
+        super(MaskRCNN, self).show_result(data, result[0], img_norm_cfg,
+                                          **kwargs)
diff --git a/tools/configs/r50_fpn_frcnn_1x.py b/tools/configs/r50_fpn_frcnn_1x.py
index 82082df..6ab3dbc 100644
--- a/tools/configs/r50_fpn_frcnn_1x.py
+++ b/tools/configs/r50_fpn_frcnn_1x.py
@@ -76,7 +76,7 @@ test_cfg = dict(
         max_num=2000,
         nms_thr=0.7,
         min_bbox_size=0),
-    rcnn=dict(score_thr=1e-3, max_per_img=100, nms_thr=0.5))
+    rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = '../data/coco/'
diff --git a/tools/configs/r50_fpn_maskrcnn_1x.py b/tools/configs/r50_fpn_maskrcnn_1x.py
index ad61857..677176c 100644
--- a/tools/configs/r50_fpn_maskrcnn_1x.py
+++ b/tools/configs/r50_fpn_maskrcnn_1x.py
@@ -89,7 +89,7 @@ test_cfg = dict(
         nms_thr=0.7,
         min_bbox_size=0),
     rcnn=dict(
-        score_thr=1e-3, max_per_img=100, nms_thr=0.5, mask_thr_binary=0.5))
+        score_thr=0.05, max_per_img=100, nms_thr=0.5, mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = '../data/coco/'
-- 
GitLab