From af55b977619d6b93aa844a5152b12fe441a8b94d Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Fri, 23 Nov 2018 23:01:33 +0800
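Subject: [PATCH] use dict to save multi-stage results

Store the multi-stage test results of CascadeRCNN in dicts keyed by
'stage{i}' (one entry per stage) and 'ensemble' (averaged scores)
instead of positional lists.

Other changes bundled in this patch:
- When test_cfg.keep_all_stages is false, return the 'ensemble' entry
  (previously element 0 of the result list was returned).
- show_result() uses the 'ensemble' entry when given a dict and passes
  the result through unchanged otherwise (previously it always took the
  last list element).
- Rename train_cfg.loss_weight to train_cfg.stage_loss_weights and
  change the third-stage weight from 0.4 to 0.25 in both cascade
  configs.
- Per-stage mask testing reads self.mask_roi_extractor[i] and
  self.mask_head[i] instead of self.mask_blocks[i] / self.mask_heads[i].
- Add "from __future__ import division" so the score average no longer
  needs an explicit float() cast.
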
---
 configs/cascade_mask_rcnn_r50_fpn_1x.py |  2 +-
 configs/cascade_rcnn_r50_fpn_1x.py      |  2 +-
 mmdet/models/detectors/cascade_rcnn.py  | 39 +++++++++++++++----------
 3 files changed, 25 insertions(+), 18 deletions(-)

diff --git a/configs/cascade_mask_rcnn_r50_fpn_1x.py b/configs/cascade_mask_rcnn_r50_fpn_1x.py
index ccda54b..9f3f8b6 100644
--- a/configs/cascade_mask_rcnn_r50_fpn_1x.py
+++ b/configs/cascade_mask_rcnn_r50_fpn_1x.py
@@ -142,7 +142,7 @@ train_cfg = dict(
             pos_weight=-1,
             debug=False)
     ],
-    loss_weight=[1, 0.5, 0.4])
+    stage_loss_weights=[1, 0.5, 0.25])
 test_cfg = dict(
     rpn=dict(
         nms_across_levels=False,
diff --git a/configs/cascade_rcnn_r50_fpn_1x.py b/configs/cascade_rcnn_r50_fpn_1x.py
index 4b4fe16..5b4a70c 100644
--- a/configs/cascade_rcnn_r50_fpn_1x.py
+++ b/configs/cascade_rcnn_r50_fpn_1x.py
@@ -128,7 +128,7 @@ train_cfg = dict(
             pos_weight=-1,
             debug=False)
     ],
-    loss_weight=[1, 0.5, 0.4])
+    stage_loss_weights=[1, 0.5, 0.25])
 test_cfg = dict(
     rpn=dict(
         nms_across_levels=False,
diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py
index 3d4a712..16ad5fe 100644
--- a/mmdet/models/detectors/cascade_rcnn.py
+++ b/mmdet/models/detectors/cascade_rcnn.py
@@ -1,3 +1,5 @@
+from __future__ import division
+
 import torch
 import torch.nn as nn
 
@@ -127,7 +129,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
 
         for i in range(self.num_stages):
             rcnn_train_cfg = self.train_cfg.rcnn[i]
-            lw = self.train_cfg.loss_weight[i]
+            lw = self.train_cfg.stage_loss_weights[i]
 
             # assign gts and sample proposals
             assign_results, sampling_results = multi_apply(
@@ -193,8 +195,8 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
         scale_factor = img_meta[0]['scale_factor']
 
         # "ms" in variable names means multi-stage
-        ms_bbox_result = []
-        ms_segm_result = []
+        ms_bbox_result = {}
+        ms_segm_result = {}
         ms_scores = []
         rcnn_test_cfg = self.test_cfg.rcnn
 
@@ -219,11 +221,11 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
                     nms_cfg=rcnn_test_cfg)
                 bbox_result = bbox2result(det_bboxes, det_labels,
                                           bbox_head.num_classes)
-                ms_bbox_result.append(bbox_result)
+                ms_bbox_result['stage{}'.format(i)] = bbox_result
 
                 if self.with_mask:
-                    mask_block = self.mask_blocks[i]
-                    mask_head = self.mask_heads[i]
+                    mask_roi_extractor = self.mask_roi_extractor[i]
+                    mask_head = self.mask_head[i]
                     if det_bboxes.shape[0] == 0:
                         segm_result = [
                             [] for _ in range(mask_head.num_classes - 1)
@@ -232,20 +234,21 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
                         _bboxes = (det_bboxes[:, :4] * scale_factor
                                    if rescale else det_bboxes)
                         mask_rois = bbox2roi([_bboxes])
-                        mask_feats = mask_block(
-                            x[:len(mask_block.featmap_strides)], mask_rois)
+                        mask_feats = mask_roi_extractor(
+                            x[:len(mask_roi_extractor.featmap_strides)],
+                            mask_rois)
                         mask_pred = mask_head(mask_feats)
                         segm_result = mask_head.get_seg_masks(
                             mask_pred, _bboxes, det_labels, rcnn_test_cfg,
                             ori_shape, scale_factor, rescale)
-                    ms_segm_result.append(segm_result)
+                    ms_segm_result['stage{}'.format(i)] = segm_result
 
             if i < self.num_stages - 1:
                 bbox_label = cls_score.argmax(dim=1)
                 rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred,
                                                   img_meta[0])
 
-        cls_score = sum(ms_scores) / float(len(ms_scores))
+        cls_score = sum(ms_scores) / len(ms_scores)
         det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes(
             rois,
             cls_score,
@@ -256,7 +259,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
             nms_cfg=rcnn_test_cfg)
         bbox_result = bbox2result(det_bboxes, det_labels,
                                   self.bbox_head[-1].num_classes)
-        ms_bbox_result.append(bbox_result)
+        ms_bbox_result['ensemble'] = bbox_result
 
         if self.with_mask:
             if det_bboxes.shape[0] == 0:
@@ -280,12 +283,12 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
                 segm_result = self.mask_head[-1].get_seg_masks(
                     merged_masks, _bboxes, det_labels, rcnn_test_cfg,
                     ori_shape, scale_factor, rescale)
-            ms_segm_result.append(segm_result)
+            ms_segm_result['ensemble'] = segm_result
 
         if not self.test_cfg.keep_all_stages:
-            ms_bbox_result = ms_bbox_result[0]
+            ms_bbox_result = ms_bbox_result['ensemble']
             if self.with_mask:
-                ms_segm_result = ms_segm_result[0]
+                ms_segm_result = ms_segm_result['ensemble']
 
         if not self.with_mask:
             return ms_bbox_result
@@ -301,5 +304,9 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
             ms_bbox_result, ms_segm_result = result
         else:
             ms_bbox_result = result
-        super(CascadeRCNN, self).show_result(data, ms_bbox_result[-1],
-                                             img_norm_cfg, **kwargs)
+        if isinstance(ms_bbox_result, dict):
+            bbox_result = ms_bbox_result['ensemble']
+        else:
+            bbox_result = ms_bbox_result
+        super(CascadeRCNN, self).show_result(data, bbox_result, img_norm_cfg,
+                                             **kwargs)
-- 
GitLab