diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py
index 16ad5feeaf2898742cf36f1013a116e667be30b1..91d3eaf4fe27a136fb1f874d66fb059388b7e364 100644
--- a/mmdet/models/detectors/cascade_rcnn.py
+++ b/mmdet/models/detectors/cascade_rcnn.py
@@ -248,7 +248,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
                 rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred,
                                                   img_meta[0])
 
-        cls_score = sum(ms_scores) / len(ms_scores)
+        cls_score = sum(ms_scores) / self.num_stages
         det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes(
             rois,
             cls_score,
@@ -286,14 +286,21 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
                 ms_segm_result['ensemble'] = segm_result
 
         if not self.test_cfg.keep_all_stages:
-            ms_bbox_result = ms_bbox_result['ensemble']
             if self.with_mask:
-                ms_segm_result = ms_segm_result['ensemble']
-
-        if not self.with_mask:
-            return ms_bbox_result
+                results = (ms_bbox_result['ensemble'],
+                           ms_segm_result['ensemble'])
+            else:
+                results = ms_bbox_result['ensemble']
         else:
-            return ms_bbox_result, ms_segm_result
+            if self.with_mask:
+                results = {
+                    stage: (ms_bbox_result[stage], ms_segm_result[stage])
+                    for stage in ms_bbox_result
+                }
+            else:
+                results = ms_bbox_result
+
+        return results
 
     def aug_test(self, img, img_meta, proposals=None, rescale=False):
         raise NotImplementedError
diff --git a/tools/test.py b/tools/test.py
index dc8dc5e85ce415b5149227b0035cf1d88d70c677..9a599c2d923f5e6d999363d21f120c0b38f71395 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -104,10 +104,19 @@ def main():
         print('Starting evaluate {}'.format(' and '.join(eval_types)))
         if eval_types == ['proposal_fast']:
             result_file = args.out
+            coco_eval(result_file, eval_types, dataset.coco)
         else:
-            result_file = args.out + '.json'
-            results2json(dataset, outputs, result_file)
-            coco_eval(result_file, eval_types, dataset.coco)
+            if not isinstance(outputs[0], dict):
+                result_file = args.out + '.json'
+                results2json(dataset, outputs, result_file)
+                coco_eval(result_file, eval_types, dataset.coco)
+            else:
+                for name in outputs[0]:
+                    print('\nEvaluating {}'.format(name))
+                    outputs_ = [out[name] for out in outputs]
+                    result_file = args.out + '.{}.json'.format(name)
+                    results2json(dataset, outputs_, result_file)
+                    coco_eval(result_file, eval_types, dataset.coco)
 
 
 if __name__ == '__main__':
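
Note: the two hunks change the shape of what CascadeRCNN.simple_test returns
and teach tools/test.py to consume it. Below is a minimal, self-contained
sketch of that contract, not part of the patch. The stage keys ('stage0',
'stage1', 'ensemble') follow the names simple_test stores when
test_cfg.keep_all_stages is True but are not shown in these hunks, and the
toy string values stand in for real per-image detection results.

    # Sketch only: mirrors the new control flow with toy data.
    def fake_simple_test(with_mask, keep_all_stages):
        # Per-stage results that simple_test accumulates before returning.
        ms_bbox_result = {'stage0': 'bb0', 'stage1': 'bb1', 'ensemble': 'bbE'}
        ms_segm_result = {'stage0': 'sg0', 'stage1': 'sg1', 'ensemble': 'sgE'}
        if not keep_all_stages:
            if with_mask:
                results = (ms_bbox_result['ensemble'],
                           ms_segm_result['ensemble'])
            else:
                results = ms_bbox_result['ensemble']
        else:
            if with_mask:
                results = {
                    stage: (ms_bbox_result[stage], ms_segm_result[stage])
                    for stage in ms_bbox_result
                }
            else:
                results = ms_bbox_result
        return results

    # tools/test.py side: one output per test image.
    outputs = [fake_simple_test(with_mask=True, keep_all_stages=True)
               for _ in range(2)]
    if not isinstance(outputs[0], dict):
        print('evaluate', outputs)      # old-style flat list of results
    else:
        for name in outputs[0]:         # 'stage0', 'stage1', 'ensemble'
            outputs_ = [out[name] for out in outputs]
            print('evaluate {}: {}'.format(name, outputs_))

Returning a per-stage dict when keep_all_stages is set lets tools/test.py
reuse the same results2json/coco_eval path for every stage as well as the
ensemble, writing one args.out + '.{name}.json' file per key.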