From 976629d4e604f07280eb621b8fc7e6023fc94a90 Mon Sep 17 00:00:00 2001 From: Cao Yuhang <yhcao6@gmail.com> Date: Mon, 17 Jun 2019 14:48:01 +0800 Subject: [PATCH] support segm evaluation using different score from bbox det --- mmdet/core/evaluation/coco_utils.py | 50 ++++++++++++++++++++++------- mmdet/core/evaluation/eval_hooks.py | 13 ++++---- tools/test.py | 12 +++---- 3 files changed, 52 insertions(+), 23 deletions(-) diff --git a/mmdet/core/evaluation/coco_utils.py b/mmdet/core/evaluation/coco_utils.py index 0ed056b..3022ad0 100644 --- a/mmdet/core/evaluation/coco_utils.py +++ b/mmdet/core/evaluation/coco_utils.py @@ -6,7 +6,7 @@ from pycocotools.cocoeval import COCOeval from .recall import eval_recalls -def coco_eval(result_file, result_types, coco, max_dets=(100, 300, 1000)): +def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000)): for res_type in result_types: assert res_type in [ 'proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints' @@ -17,16 +17,17 @@ def coco_eval(result_file, result_types, coco, max_dets=(100, 300, 1000)): assert isinstance(coco, COCO) if result_types == ['proposal_fast']: - ar = fast_eval_recall(result_file, coco, np.array(max_dets)) + ar = fast_eval_recall(result_files, coco, np.array(max_dets)) for i, num in enumerate(max_dets): print('AR@{}\t= {:.4f}'.format(num, ar[i])) return - assert result_file.endswith('.json') - coco_dets = coco.loadRes(result_file) - - img_ids = coco.getImgIds() for res_type in result_types: + result_file = result_files[res_type] + assert result_file.endswith('.json') + + coco_dets = coco.loadRes(result_file) + img_ids = coco.getImgIds() iou_type = 'bbox' if res_type == 'proposal' else res_type cocoEval = COCOeval(coco, coco_dets, iou_type) cocoEval.params.imgIds = img_ids @@ -118,32 +119,59 @@ def det2json(dataset, results): def segm2json(dataset, results): - json_results = [] + bbox_json_results = [] + segm_json_results = [] for idx in range(len(dataset)): img_id = dataset.img_ids[idx] det, seg = results[idx] for label in range(len(det)): + # bbox results bboxes = det[label] - segms = seg[label] for i in range(bboxes.shape[0]): data = dict() data['image_id'] = img_id data['bbox'] = xyxy2xywh(bboxes[i]) data['score'] = float(bboxes[i][4]) data['category_id'] = dataset.cat_ids[label] + bbox_json_results.append(data) + + # segm results + # some detectors use different score for det and segm + if len(seg) == 2: + segms = seg[0][label] + mask_score = seg[1][label] + else: + segms = seg[label] + mask_score = [bbox[4] for bbox in bboxes] + for i in range(bboxes.shape[0]): + data = dict() + data['image_id'] = img_id + data['score'] = float(mask_score[i]) + data['category_id'] = dataset.cat_ids[label] segms[i]['counts'] = segms[i]['counts'].decode() data['segmentation'] = segms[i] - json_results.append(data) - return json_results + segm_json_results.append(data) + return bbox_json_results, segm_json_results def results2json(dataset, results, out_file): + result_files = dict() if isinstance(results[0], list): json_results = det2json(dataset, results) + result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox') + result_files['proposal'] = '{}.{}.json'.format(out_file, 'bbox') + mmcv.dump(json_results, result_files['bbox']) elif isinstance(results[0], tuple): json_results = segm2json(dataset, results) + result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox') + result_files['proposal'] = '{}.{}.json'.format(out_file, 'bbox') + result_files['segm'] = '{}.{}.json'.format(out_file, 'segm') + mmcv.dump(json_results[0], result_files['bbox']) + mmcv.dump(json_results[1], result_files['segm']) elif isinstance(results[0], np.ndarray): json_results = proposal2json(dataset, results) + result_files['proposal'] = '{}.{}.json'.format(out_file, 'proposal') + mmcv.dump(json_results, result_files['proposal']) else: raise TypeError('invalid type of results') - mmcv.dump(json_results, out_file) + return result_files diff --git a/mmdet/core/evaluation/eval_hooks.py b/mmdet/core/evaluation/eval_hooks.py index fb12578..806067d 100644 --- a/mmdet/core/evaluation/eval_hooks.py +++ b/mmdet/core/evaluation/eval_hooks.py @@ -135,15 +135,15 @@ class CocoDistEvalRecallHook(DistEvalHook): class CocoDistEvalmAPHook(DistEvalHook): def evaluate(self, runner, results): - tmp_file = osp.join(runner.work_dir, 'temp_0.json') - results2json(self.dataset, results, tmp_file) + tmp_file = osp.join(runner.work_dir, 'temp_0') + result_files = results2json(self.dataset, results, tmp_file) - res_types = ['bbox', - 'segm'] if runner.model.module.with_mask else ['bbox'] + res_types = ['bbox', 'segm' + ] if runner.model.module.with_mask else ['bbox'] cocoGt = self.dataset.coco - cocoDt = cocoGt.loadRes(tmp_file) imgIds = cocoGt.getImgIds() for res_type in res_types: + cocoDt = cocoGt.loadRes(result_files[res_type]) iou_type = res_type cocoEval = COCOeval(cocoGt, cocoDt, iou_type) cocoEval.params.imgIds = imgIds @@ -159,4 +159,5 @@ class CocoDistEvalmAPHook(DistEvalHook): '{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} ' '{ap[4]:.3f} {ap[5]:.3f}').format(ap=cocoEval.stats[:6]) runner.log_buffer.ready = True - os.remove(tmp_file) + for res_type in res_types: + os.remove(result_files[res_type]) diff --git a/tools/test.py b/tools/test.py index af950aa..df629f3 100644 --- a/tools/test.py +++ b/tools/test.py @@ -184,16 +184,16 @@ def main(): coco_eval(result_file, eval_types, dataset.coco) else: if not isinstance(outputs[0], dict): - result_file = args.out + '.json' - results2json(dataset, outputs, result_file) - coco_eval(result_file, eval_types, dataset.coco) + result_files = results2json(dataset, outputs, args.out) + coco_eval(result_files, eval_types, dataset.coco) else: for name in outputs[0]: print('\nEvaluating {}'.format(name)) outputs_ = [out[name] for out in outputs] - result_file = args.out + '.{}.json'.format(name) - results2json(dataset, outputs_, result_file) - coco_eval(result_file, eval_types, dataset.coco) + result_file = args.out + '.{}'.format(name) + result_files = results2json(dataset, outputs_, + result_file) + coco_eval(result_files, eval_types, dataset.coco) if __name__ == '__main__': -- GitLab