From 4e1fd9bdbd9d25924bead494c46bb416f1be8d52 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Mon, 26 Nov 2018 19:00:50 +0800
Subject: [PATCH] allow test multi-stage results

---
 mmdet/models/detectors/cascade_rcnn.py | 21 ++++++++++++++-------
 tools/test.py                          | 15 ++++++++++++---
 2 files changed, 26 insertions(+), 10 deletions(-)

diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py
index 16ad5fe..91d3eaf 100644
--- a/mmdet/models/detectors/cascade_rcnn.py
+++ b/mmdet/models/detectors/cascade_rcnn.py
@@ -248,7 +248,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
                 rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred,
                                                   img_meta[0])
 
-        cls_score = sum(ms_scores) / len(ms_scores)
+        cls_score = sum(ms_scores) / self.num_stages
         det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes(
             rois,
             cls_score,
@@ -286,14 +286,21 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
             ms_segm_result['ensemble'] = segm_result
 
         if not self.test_cfg.keep_all_stages:
-            ms_bbox_result = ms_bbox_result['ensemble']
             if self.with_mask:
-                ms_segm_result = ms_segm_result['ensemble']
-
-        if not self.with_mask:
-            return ms_bbox_result
+                results = (ms_bbox_result['ensemble'],
+                           ms_segm_result['ensemble'])
+            else:
+                results = ms_bbox_result['ensemble']
         else:
-            return ms_bbox_result, ms_segm_result
+            if self.with_mask:
+                results = {
+                    stage: (ms_bbox_result[stage], ms_segm_result[stage])
+                    for stage in ms_bbox_result
+                }
+            else:
+                results = ms_bbox_result
+
+        return results
 
     def aug_test(self, img, img_meta, proposals=None, rescale=False):
         raise NotImplementedError
diff --git a/tools/test.py b/tools/test.py
index dc8dc5e..9a599c2 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -104,10 +104,19 @@ def main():
             print('Starting evaluate {}'.format(' and '.join(eval_types)))
             if eval_types == ['proposal_fast']:
                 result_file = args.out
+                coco_eval(result_file, eval_types, dataset.coco)
             else:
-                result_file = args.out + '.json'
-                results2json(dataset, outputs, result_file)
-            coco_eval(result_file, eval_types, dataset.coco)
+                if not isinstance(outputs[0], dict):
+                    result_file = args.out + '.json'
+                    results2json(dataset, outputs, result_file)
+                    coco_eval(result_file, eval_types, dataset.coco)
+                else:
+                    for name in outputs[0]:
+                        print('\nEvaluating {}'.format(name))
+                        outputs_ = [out[name] for out in outputs]
+                        result_file = args.out + '.{}.json'.format(name)
+                        results2json(dataset, outputs_, result_file)
+                        coco_eval(result_file, eval_types, dataset.coco)
 
 
 if __name__ == '__main__':
-- 
GitLab