From 4e1fd9bdbd9d25924bead494c46bb416f1be8d52 Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Mon, 26 Nov 2018 19:00:50 +0800 Subject: [PATCH] allow test multi-stage results --- mmdet/models/detectors/cascade_rcnn.py | 21 ++++++++++++++------- tools/test.py | 15 ++++++++++++--- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py index 16ad5fe..91d3eaf 100644 --- a/mmdet/models/detectors/cascade_rcnn.py +++ b/mmdet/models/detectors/cascade_rcnn.py @@ -248,7 +248,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred, img_meta[0]) - cls_score = sum(ms_scores) / len(ms_scores) + cls_score = sum(ms_scores) / self.num_stages det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes( rois, cls_score, @@ -286,14 +286,21 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ms_segm_result['ensemble'] = segm_result if not self.test_cfg.keep_all_stages: - ms_bbox_result = ms_bbox_result['ensemble'] if self.with_mask: - ms_segm_result = ms_segm_result['ensemble'] - - if not self.with_mask: - return ms_bbox_result + results = (ms_bbox_result['ensemble'], + ms_segm_result['ensemble']) + else: + results = ms_bbox_result['ensemble'] else: - return ms_bbox_result, ms_segm_result + if self.with_mask: + results = { + stage: (ms_bbox_result[stage], ms_segm_result[stage]) + for stage in ms_bbox_result + } + else: + results = ms_bbox_result + + return results def aug_test(self, img, img_meta, proposals=None, rescale=False): raise NotImplementedError diff --git a/tools/test.py b/tools/test.py index dc8dc5e..9a599c2 100644 --- a/tools/test.py +++ b/tools/test.py @@ -104,10 +104,19 @@ def main(): print('Starting evaluate {}'.format(' and '.join(eval_types))) if eval_types == ['proposal_fast']: result_file = args.out + coco_eval(result_file, eval_types, dataset.coco) else: - result_file = args.out + '.json' - results2json(dataset, outputs, result_file) - coco_eval(result_file, eval_types, dataset.coco) + if not isinstance(outputs[0], dict): + result_file = args.out + '.json' + results2json(dataset, outputs, result_file) + coco_eval(result_file, eval_types, dataset.coco) + else: + for name in outputs[0]: + print('\nEvaluating {}'.format(name)) + outputs_ = [out[name] for out in outputs] + result_file = args.out + '.{}.json'.format(name) + results2json(dataset, outputs_, result_file) + coco_eval(result_file, eval_types, dataset.coco) if __name__ == '__main__': -- GitLab