From 364698b6bc44288e12e2a432871c17ada8dde42c Mon Sep 17 00:00:00 2001
From: Cao Yuhang <yhcao6@gmail.com>
Date: Sat, 14 Dec 2019 18:16:26 +0800
Subject: [PATCH] Remove keep all stage code in HTC and Cascade RCNN (#1806)

* Remove keep all stage code

* remove keep_all_stage in config
---
 configs/cascade_mask_rcnn_r101_fpn_1x.py      |  3 +-
 configs/cascade_mask_rcnn_r50_caffe_c4_1x.py  |  3 +-
 configs/cascade_mask_rcnn_r50_fpn_1x.py       |  3 +-
 .../cascade_mask_rcnn_x101_32x4d_fpn_1x.py    |  3 +-
 .../cascade_mask_rcnn_x101_64x4d_fpn_1x.py    |  3 +-
 configs/cascade_rcnn_r101_fpn_1x.py           |  3 +-
 configs/cascade_rcnn_r50_caffe_c4_1x.py       |  3 +-
 configs/cascade_rcnn_r50_fpn_1x.py            |  3 +-
 configs/cascade_rcnn_x101_32x4d_fpn_1x.py     |  3 +-
 configs/cascade_rcnn_x101_64x4d_fpn_1x.py     |  3 +-
 ...ascade_mask_rcnn_dconv_c3-c5_r50_fpn_1x.py |  3 +-
 .../cascade_rcnn_dconv_c3-c5_r50_fpn_1x.py    |  3 +-
 .../cascade_mask_rcnn_hrnetv2p_w32_20e.py     |  3 +-
 .../hrnet/cascade_rcnn_hrnetv2p_w32_20e.py    |  3 +-
 configs/hrnet/htc_hrnetv2p_w32_20e.py         |  3 +-
 ...-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py |  3 +-
 configs/htc/htc_r101_fpn_20e.py               |  3 +-
 configs/htc/htc_r50_fpn_1x.py                 |  3 +-
 configs/htc/htc_r50_fpn_20e.py                |  3 +-
 .../htc/htc_without_semantic_r50_fpn_1x.py    |  3 +-
 configs/htc/htc_x101_32x4d_fpn_20e_16gpu.py   |  3 +-
 configs/htc/htc_x101_64x4d_fpn_20e_16gpu.py   |  3 +-
 mmdet/models/detectors/cascade_rcnn.py        | 51 ++-----------------
 mmdet/models/detectors/htc.py                 | 45 ++--------------
 24 files changed, 28 insertions(+), 134 deletions(-)

diff --git a/configs/cascade_mask_rcnn_r101_fpn_1x.py b/configs/cascade_mask_rcnn_r101_fpn_1x.py
index 0ad9c88..5ac4075 100644
--- a/configs/cascade_mask_rcnn_r101_fpn_1x.py
+++ b/configs/cascade_mask_rcnn_r101_fpn_1x.py
@@ -174,8 +174,7 @@ test_cfg = dict(
         score_thr=0.05,
         nms=dict(type='nms', iou_thr=0.5),
         max_per_img=100,
-        mask_thr_binary=0.5),
-    keep_all_stages=False)
+        mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/cascade_mask_rcnn_r50_caffe_c4_1x.py b/configs/cascade_mask_rcnn_r50_caffe_c4_1x.py
index dd5f356..7ef36d5 100644
--- a/configs/cascade_mask_rcnn_r50_caffe_c4_1x.py
+++ b/configs/cascade_mask_rcnn_r50_caffe_c4_1x.py
@@ -176,8 +176,7 @@ test_cfg = dict(
         score_thr=0.05,
         nms=dict(type='nms', iou_thr=0.5),
         max_per_img=100,
-        mask_thr_binary=0.5),
-    keep_all_stages=False)
+        mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/cascade_mask_rcnn_r50_fpn_1x.py b/configs/cascade_mask_rcnn_r50_fpn_1x.py
index c9f007e..e23e159 100644
--- a/configs/cascade_mask_rcnn_r50_fpn_1x.py
+++ b/configs/cascade_mask_rcnn_r50_fpn_1x.py
@@ -174,8 +174,7 @@ test_cfg = dict(
         score_thr=0.05,
         nms=dict(type='nms', iou_thr=0.5),
         max_per_img=100,
-        mask_thr_binary=0.5),
-    keep_all_stages=False)
+        mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/cascade_mask_rcnn_x101_32x4d_fpn_1x.py b/configs/cascade_mask_rcnn_x101_32x4d_fpn_1x.py
index 3167be4..723462c 100644
--- a/configs/cascade_mask_rcnn_x101_32x4d_fpn_1x.py
+++ b/configs/cascade_mask_rcnn_x101_32x4d_fpn_1x.py
@@ -176,8 +176,7 @@ test_cfg = dict(
         score_thr=0.05,
         nms=dict(type='nms', iou_thr=0.5),
         max_per_img=100,
-        mask_thr_binary=0.5),
-    keep_all_stages=False)
+        mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/cascade_mask_rcnn_x101_64x4d_fpn_1x.py b/configs/cascade_mask_rcnn_x101_64x4d_fpn_1x.py
index 0c5434e..b8ad462 100644
--- a/configs/cascade_mask_rcnn_x101_64x4d_fpn_1x.py
+++ b/configs/cascade_mask_rcnn_x101_64x4d_fpn_1x.py
@@ -176,8 +176,7 @@ test_cfg = dict(
         score_thr=0.05,
         nms=dict(type='nms', iou_thr=0.5),
         max_per_img=100,
-        mask_thr_binary=0.5),
-    keep_all_stages=False)
+        mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/cascade_rcnn_r101_fpn_1x.py b/configs/cascade_rcnn_r101_fpn_1x.py
index a790c2b..9f3d088 100644
--- a/configs/cascade_rcnn_r101_fpn_1x.py
+++ b/configs/cascade_rcnn_r101_fpn_1x.py
@@ -155,8 +155,7 @@ test_cfg = dict(
         nms_thr=0.7,
         min_bbox_size=0),
     rcnn=dict(
-        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100),
-    keep_all_stages=False)
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/cascade_rcnn_r50_caffe_c4_1x.py b/configs/cascade_rcnn_r50_caffe_c4_1x.py
index 0dd10ab..5722d41 100644
--- a/configs/cascade_rcnn_r50_caffe_c4_1x.py
+++ b/configs/cascade_rcnn_r50_caffe_c4_1x.py
@@ -164,8 +164,7 @@ test_cfg = dict(
         nms_thr=0.7,
         min_bbox_size=0),
     rcnn=dict(
-        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100),
-    keep_all_stages=False)
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/cascade_rcnn_r50_fpn_1x.py b/configs/cascade_rcnn_r50_fpn_1x.py
index 96269fa..56e3ee9 100644
--- a/configs/cascade_rcnn_r50_fpn_1x.py
+++ b/configs/cascade_rcnn_r50_fpn_1x.py
@@ -155,8 +155,7 @@ test_cfg = dict(
         nms_thr=0.7,
         min_bbox_size=0),
     rcnn=dict(
-        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100),
-    keep_all_stages=False)
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/cascade_rcnn_x101_32x4d_fpn_1x.py b/configs/cascade_rcnn_x101_32x4d_fpn_1x.py
index 6de3d37..397d0b8 100644
--- a/configs/cascade_rcnn_x101_32x4d_fpn_1x.py
+++ b/configs/cascade_rcnn_x101_32x4d_fpn_1x.py
@@ -157,8 +157,7 @@ test_cfg = dict(
         nms_thr=0.7,
         min_bbox_size=0),
     rcnn=dict(
-        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100),
-    keep_all_stages=False)
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/cascade_rcnn_x101_64x4d_fpn_1x.py b/configs/cascade_rcnn_x101_64x4d_fpn_1x.py
index d6e9d1f..0d31f88 100644
--- a/configs/cascade_rcnn_x101_64x4d_fpn_1x.py
+++ b/configs/cascade_rcnn_x101_64x4d_fpn_1x.py
@@ -157,8 +157,7 @@ test_cfg = dict(
         nms_thr=0.7,
         min_bbox_size=0),
     rcnn=dict(
-        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100),
-    keep_all_stages=False)
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/dcn/cascade_mask_rcnn_dconv_c3-c5_r50_fpn_1x.py b/configs/dcn/cascade_mask_rcnn_dconv_c3-c5_r50_fpn_1x.py
index 27476d3..c27fff2 100644
--- a/configs/dcn/cascade_mask_rcnn_dconv_c3-c5_r50_fpn_1x.py
+++ b/configs/dcn/cascade_mask_rcnn_dconv_c3-c5_r50_fpn_1x.py
@@ -177,8 +177,7 @@ test_cfg = dict(
         score_thr=0.05,
         nms=dict(type='nms', iou_thr=0.5),
         max_per_img=100,
-        mask_thr_binary=0.5),
-    keep_all_stages=False)
+        mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/dcn/cascade_rcnn_dconv_c3-c5_r50_fpn_1x.py b/configs/dcn/cascade_rcnn_dconv_c3-c5_r50_fpn_1x.py
index 9f9f10c..2a4740b 100644
--- a/configs/dcn/cascade_rcnn_dconv_c3-c5_r50_fpn_1x.py
+++ b/configs/dcn/cascade_rcnn_dconv_c3-c5_r50_fpn_1x.py
@@ -158,8 +158,7 @@ test_cfg = dict(
         nms_thr=0.7,
         min_bbox_size=0),
     rcnn=dict(
-        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100),
-    keep_all_stages=False)
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/hrnet/cascade_mask_rcnn_hrnetv2p_w32_20e.py b/configs/hrnet/cascade_mask_rcnn_hrnetv2p_w32_20e.py
index 17fe945..ae76e5a 100644
--- a/configs/hrnet/cascade_mask_rcnn_hrnetv2p_w32_20e.py
+++ b/configs/hrnet/cascade_mask_rcnn_hrnetv2p_w32_20e.py
@@ -190,8 +190,7 @@ test_cfg = dict(
         score_thr=0.05,
         nms=dict(type='nms', iou_thr=0.5),
         max_per_img=100,
-        mask_thr_binary=0.5),
-    keep_all_stages=False)
+        mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/hrnet/cascade_rcnn_hrnetv2p_w32_20e.py b/configs/hrnet/cascade_rcnn_hrnetv2p_w32_20e.py
index 65eedd1..48c0137 100644
--- a/configs/hrnet/cascade_rcnn_hrnetv2p_w32_20e.py
+++ b/configs/hrnet/cascade_rcnn_hrnetv2p_w32_20e.py
@@ -171,8 +171,7 @@ test_cfg = dict(
         nms_thr=0.7,
         min_bbox_size=0),
     rcnn=dict(
-        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100),
-    keep_all_stages=False)
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/hrnet/htc_hrnetv2p_w32_20e.py b/configs/hrnet/htc_hrnetv2p_w32_20e.py
index b1f9ff5..8279f24 100644
--- a/configs/hrnet/htc_hrnetv2p_w32_20e.py
+++ b/configs/hrnet/htc_hrnetv2p_w32_20e.py
@@ -207,8 +207,7 @@ test_cfg = dict(
         score_thr=0.001,
         nms=dict(type='nms', iou_thr=0.5),
         max_per_img=100,
-        mask_thr_binary=0.5),
-    keep_all_stages=False)
+        mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py b/configs/htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py
index f06904c..2072f29 100644
--- a/configs/htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py
+++ b/configs/htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py
@@ -199,8 +199,7 @@ test_cfg = dict(
         score_thr=0.001,
         nms=dict(type='nms', iou_thr=0.5),
         max_per_img=100,
-        mask_thr_binary=0.5),
-    keep_all_stages=False)
+        mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/htc/htc_r101_fpn_20e.py b/configs/htc/htc_r101_fpn_20e.py
index 36584a3..661c564 100644
--- a/configs/htc/htc_r101_fpn_20e.py
+++ b/configs/htc/htc_r101_fpn_20e.py
@@ -191,8 +191,7 @@ test_cfg = dict(
         score_thr=0.001,
         nms=dict(type='nms', iou_thr=0.5),
         max_per_img=100,
-        mask_thr_binary=0.5),
-    keep_all_stages=False)
+        mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/htc/htc_r50_fpn_1x.py b/configs/htc/htc_r50_fpn_1x.py
index d77d60c..4945f2e 100644
--- a/configs/htc/htc_r50_fpn_1x.py
+++ b/configs/htc/htc_r50_fpn_1x.py
@@ -191,8 +191,7 @@ test_cfg = dict(
         score_thr=0.001,
         nms=dict(type='nms', iou_thr=0.5),
         max_per_img=100,
-        mask_thr_binary=0.5),
-    keep_all_stages=False)
+        mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/htc/htc_r50_fpn_20e.py b/configs/htc/htc_r50_fpn_20e.py
index 9bc49af..ccf73a0 100644
--- a/configs/htc/htc_r50_fpn_20e.py
+++ b/configs/htc/htc_r50_fpn_20e.py
@@ -191,8 +191,7 @@ test_cfg = dict(
         score_thr=0.001,
         nms=dict(type='nms', iou_thr=0.5),
         max_per_img=100,
-        mask_thr_binary=0.5),
-    keep_all_stages=False)
+        mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/htc/htc_without_semantic_r50_fpn_1x.py b/configs/htc/htc_without_semantic_r50_fpn_1x.py
index 2a4b777..de5dfcc 100644
--- a/configs/htc/htc_without_semantic_r50_fpn_1x.py
+++ b/configs/htc/htc_without_semantic_r50_fpn_1x.py
@@ -176,8 +176,7 @@ test_cfg = dict(
         score_thr=0.001,
         nms=dict(type='nms', iou_thr=0.5),
         max_per_img=100,
-        mask_thr_binary=0.5),
-    keep_all_stages=False)
+        mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/htc/htc_x101_32x4d_fpn_20e_16gpu.py b/configs/htc/htc_x101_32x4d_fpn_20e_16gpu.py
index 21db137..915a54e 100644
--- a/configs/htc/htc_x101_32x4d_fpn_20e_16gpu.py
+++ b/configs/htc/htc_x101_32x4d_fpn_20e_16gpu.py
@@ -193,8 +193,7 @@ test_cfg = dict(
         score_thr=0.001,
         nms=dict(type='nms', iou_thr=0.5),
         max_per_img=100,
-        mask_thr_binary=0.5),
-    keep_all_stages=False)
+        mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/configs/htc/htc_x101_64x4d_fpn_20e_16gpu.py b/configs/htc/htc_x101_64x4d_fpn_20e_16gpu.py
index 6c5dada..99ceefc 100644
--- a/configs/htc/htc_x101_64x4d_fpn_20e_16gpu.py
+++ b/configs/htc/htc_x101_64x4d_fpn_20e_16gpu.py
@@ -193,8 +193,7 @@ test_cfg = dict(
         score_thr=0.001,
         nms=dict(type='nms', iou_thr=0.5),
         max_per_img=100,
-        mask_thr_binary=0.5),
-    keep_all_stages=False)
+        mask_thr_binary=0.5))
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py
index 86e971b..e79b189 100644
--- a/mmdet/models/detectors/cascade_rcnn.py
+++ b/mmdet/models/detectors/cascade_rcnn.py
@@ -339,41 +339,6 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
             cls_score, bbox_pred = bbox_head(bbox_feats)
             ms_scores.append(cls_score)
 
-            if self.test_cfg.keep_all_stages:
-                det_bboxes, det_labels = bbox_head.get_det_bboxes(
-                    rois,
-                    cls_score,
-                    bbox_pred,
-                    img_shape,
-                    scale_factor,
-                    rescale=rescale,
-                    cfg=rcnn_test_cfg)
-                bbox_result = bbox2result(det_bboxes, det_labels,
-                                          bbox_head.num_classes)
-                ms_bbox_result['stage{}'.format(i)] = bbox_result
-
-                if self.with_mask:
-                    mask_roi_extractor = self.mask_roi_extractor[i]
-                    mask_head = self.mask_head[i]
-                    if det_bboxes.shape[0] == 0:
-                        mask_classes = mask_head.num_classes - 1
-                        segm_result = [[] for _ in range(mask_classes)]
-                    else:
-                        _bboxes = (
-                            det_bboxes[:, :4] *
-                            scale_factor if rescale else det_bboxes)
-                        mask_rois = bbox2roi([_bboxes])
-                        mask_feats = mask_roi_extractor(
-                            x[:len(mask_roi_extractor.featmap_strides)],
-                            mask_rois)
-                        if self.with_shared_head:
-                            mask_feats = self.shared_head(mask_feats, i)
-                        mask_pred = mask_head(mask_feats)
-                        segm_result = mask_head.get_seg_masks(
-                            mask_pred, _bboxes, det_labels, rcnn_test_cfg,
-                            ori_shape, scale_factor, rescale)
-                    ms_segm_result['stage{}'.format(i)] = segm_result
-
             if i < self.num_stages - 1:
                 bbox_label = cls_score.argmax(dim=1)
                 rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred,
@@ -425,20 +390,10 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
                     ori_shape, scale_factor, rescale)
             ms_segm_result['ensemble'] = segm_result
 
-        if not self.test_cfg.keep_all_stages:
-            if self.with_mask:
-                results = (ms_bbox_result['ensemble'],
-                           ms_segm_result['ensemble'])
-            else:
-                results = ms_bbox_result['ensemble']
+        if self.with_mask:
+            results = (ms_bbox_result['ensemble'], ms_segm_result['ensemble'])
         else:
-            if self.with_mask:
-                results = {
-                    stage: (ms_bbox_result[stage], ms_segm_result[stage])
-                    for stage in ms_bbox_result
-                }
-            else:
-                results = ms_bbox_result
+            results = ms_bbox_result['ensemble']
 
         return results
 
diff --git a/mmdet/models/detectors/htc.py b/mmdet/models/detectors/htc.py
index 097d109..a989e17 100644
--- a/mmdet/models/detectors/htc.py
+++ b/mmdet/models/detectors/htc.py
@@ -334,35 +334,6 @@ class HybridTaskCascade(CascadeRCNN):
                 i, x, rois, semantic_feat=semantic_feat)
             ms_scores.append(cls_score)
 
-            if self.test_cfg.keep_all_stages:
-                det_bboxes, det_labels = bbox_head.get_det_bboxes(
-                    rois,
-                    cls_score,
-                    bbox_pred,
-                    img_shape,
-                    scale_factor,
-                    rescale=rescale,
-                    cfg=rcnn_test_cfg)
-                bbox_result = bbox2result(det_bboxes, det_labels,
-                                          bbox_head.num_classes)
-                ms_bbox_result['stage{}'.format(i)] = bbox_result
-
-                if self.with_mask:
-                    mask_head = self.mask_head[i]
-                    if det_bboxes.shape[0] == 0:
-                        mask_classes = mask_head.num_classes - 1
-                        segm_result = [[] for _ in range(mask_classes)]
-                    else:
-                        _bboxes = (
-                            det_bboxes[:, :4] *
-                            scale_factor if rescale else det_bboxes)
-                        mask_pred = self._mask_forward_test(
-                            i, x, _bboxes, semantic_feat=semantic_feat)
-                        segm_result = mask_head.get_seg_masks(
-                            mask_pred, _bboxes, det_labels, rcnn_test_cfg,
-                            ori_shape, scale_factor, rescale)
-                    ms_segm_result['stage{}'.format(i)] = segm_result
-
             if i < self.num_stages - 1:
                 bbox_label = cls_score.argmax(dim=1)
                 rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred,
@@ -415,20 +386,10 @@ class HybridTaskCascade(CascadeRCNN):
                     ori_shape, scale_factor, rescale)
             ms_segm_result['ensemble'] = segm_result
 
-        if not self.test_cfg.keep_all_stages:
-            if self.with_mask:
-                results = (ms_bbox_result['ensemble'],
-                           ms_segm_result['ensemble'])
-            else:
-                results = ms_bbox_result['ensemble']
+        if self.with_mask:
+            results = (ms_bbox_result['ensemble'], ms_segm_result['ensemble'])
         else:
-            if self.with_mask:
-                results = {
-                    stage: (ms_bbox_result[stage], ms_segm_result[stage])
-                    for stage in ms_bbox_result
-                }
-            else:
-                results = ms_bbox_result
+            results = ms_bbox_result['ensemble']
 
         return results
 
-- 
GitLab