Skip to content
Snippets Groups Projects
Commit 4e1fd9bd authored by Kai Chen's avatar Kai Chen
Browse files

allow test multi-stage results

parent 22286216
No related branches found
No related tags found
No related merge requests found
...@@ -248,7 +248,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ...@@ -248,7 +248,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred, rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred,
img_meta[0]) img_meta[0])
cls_score = sum(ms_scores) / len(ms_scores) cls_score = sum(ms_scores) / self.num_stages
det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes( det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes(
rois, rois,
cls_score, cls_score,
...@@ -286,14 +286,21 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ...@@ -286,14 +286,21 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
ms_segm_result['ensemble'] = segm_result ms_segm_result['ensemble'] = segm_result
if not self.test_cfg.keep_all_stages: if not self.test_cfg.keep_all_stages:
ms_bbox_result = ms_bbox_result['ensemble']
if self.with_mask: if self.with_mask:
ms_segm_result = ms_segm_result['ensemble'] results = (ms_bbox_result['ensemble'],
ms_segm_result['ensemble'])
if not self.with_mask: else:
return ms_bbox_result results = ms_bbox_result['ensemble']
else: else:
return ms_bbox_result, ms_segm_result if self.with_mask:
results = {
stage: (ms_bbox_result[stage], ms_segm_result[stage])
for stage in ms_bbox_result
}
else:
results = ms_bbox_result
return results
def aug_test(self, img, img_meta, proposals=None, rescale=False): def aug_test(self, img, img_meta, proposals=None, rescale=False):
raise NotImplementedError raise NotImplementedError
......
...@@ -104,10 +104,19 @@ def main(): ...@@ -104,10 +104,19 @@ def main():
print('Starting evaluate {}'.format(' and '.join(eval_types))) print('Starting evaluate {}'.format(' and '.join(eval_types)))
if eval_types == ['proposal_fast']: if eval_types == ['proposal_fast']:
result_file = args.out result_file = args.out
coco_eval(result_file, eval_types, dataset.coco)
else: else:
result_file = args.out + '.json' if not isinstance(outputs[0], dict):
results2json(dataset, outputs, result_file) result_file = args.out + '.json'
coco_eval(result_file, eval_types, dataset.coco) results2json(dataset, outputs, result_file)
coco_eval(result_file, eval_types, dataset.coco)
else:
for name in outputs[0]:
print('\nEvaluating {}'.format(name))
outputs_ = [out[name] for out in outputs]
result_file = args.out + '.{}.json'.format(name)
results2json(dataset, outputs_, result_file)
coco_eval(result_file, eval_types, dataset.coco)
if __name__ == '__main__': if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment