From 976629d4e604f07280eb621b8fc7e6023fc94a90 Mon Sep 17 00:00:00 2001
From: Cao Yuhang <yhcao6@gmail.com>
Date: Mon, 17 Jun 2019 14:48:01 +0800
Subject: [PATCH] support segm evaluation using different scores from bbox det

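Some detectors (mask scoring methods, for example) predict a dedicated
mask score instead of reusing the bbox classification score. segm2json
now accepts the per-image mask result either in the old format (a list
of per-class segms) or as a 2-tuple (segms, mask_scores), and splits
the output into separate bbox and segm JSON files. results2json
therefore returns a dict mapping each result type to its JSON file,
which coco_eval and CocoDistEvalmAPHook consume.

A minimal sketch of the two accepted mask result formats (variable
names below are placeholders, not part of the API):

    det, seg = result
    # old format: seg[label] is a list of RLE dicts; mask i of class
    # `label` reuses the bbox score det[label][i][4]
    # new format: seg == (segms, mask_scores), where
    #   seg[0][label][i] is the RLE dict for instance i of class `label`
    #   seg[1][label][i] is the score written to the segm JSON

Usage of the new return value ('out' is a placeholder prefix):

    result_files = results2json(dataset, outputs, 'out')
    # {'bbox': 'out.bbox.json',
    #  'proposal': 'out.bbox.json',  # proposal eval reuses bbox dets
    #  'segm': 'out.segm.json'}
    coco_eval(result_files, ['bbox', 'segm'], dataset.coco)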
---
 mmdet/core/evaluation/coco_utils.py | 50 ++++++++++++++++++++++-------
 mmdet/core/evaluation/eval_hooks.py |  9 +++++----
 tools/test.py                       | 12 +++----
 3 files changed, 50 insertions(+), 21 deletions(-)

diff --git a/mmdet/core/evaluation/coco_utils.py b/mmdet/core/evaluation/coco_utils.py
index 0ed056b..3022ad0 100644
--- a/mmdet/core/evaluation/coco_utils.py
+++ b/mmdet/core/evaluation/coco_utils.py
@@ -6,7 +6,7 @@ from pycocotools.cocoeval import COCOeval
 from .recall import eval_recalls
 
 
-def coco_eval(result_file, result_types, coco, max_dets=(100, 300, 1000)):
+def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000)):
     for res_type in result_types:
         assert res_type in [
             'proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'
@@ -17,16 +17,17 @@ def coco_eval(result_file, result_types, coco, max_dets=(100, 300, 1000)):
     assert isinstance(coco, COCO)
 
     if result_types == ['proposal_fast']:
-        ar = fast_eval_recall(result_file, coco, np.array(max_dets))
+        ar = fast_eval_recall(result_files, coco, np.array(max_dets))
         for i, num in enumerate(max_dets):
             print('AR@{}\t= {:.4f}'.format(num, ar[i]))
         return
 
-    assert result_file.endswith('.json')
-    coco_dets = coco.loadRes(result_file)
-
-    img_ids = coco.getImgIds()
+    img_ids = coco.getImgIds()
     for res_type in result_types:
+        result_file = result_files[res_type]
+        assert result_file.endswith('.json')
+
+        coco_dets = coco.loadRes(result_file)
         iou_type = 'bbox' if res_type == 'proposal' else res_type
         cocoEval = COCOeval(coco, coco_dets, iou_type)
         cocoEval.params.imgIds = img_ids
@@ -118,32 +119,59 @@ def det2json(dataset, results):
 
 
 def segm2json(dataset, results):
-    json_results = []
+    bbox_json_results = []
+    segm_json_results = []
     for idx in range(len(dataset)):
         img_id = dataset.img_ids[idx]
         det, seg = results[idx]
         for label in range(len(det)):
+            # bbox results
             bboxes = det[label]
-            segms = seg[label]
             for i in range(bboxes.shape[0]):
                 data = dict()
                 data['image_id'] = img_id
                 data['bbox'] = xyxy2xywh(bboxes[i])
                 data['score'] = float(bboxes[i][4])
                 data['category_id'] = dataset.cat_ids[label]
+                bbox_json_results.append(data)
+
+            # segm results
+            # some detectors use different scores for bbox and mask
+            if isinstance(seg, tuple):
+                segms = seg[0][label]
+                mask_score = seg[1][label]
+            else:
+                segms = seg[label]
+                mask_score = [bbox[4] for bbox in bboxes]
+            for i in range(bboxes.shape[0]):
+                data = dict()
+                data['image_id'] = img_id
+                data['score'] = float(mask_score[i])
+                data['category_id'] = dataset.cat_ids[label]
                 segms[i]['counts'] = segms[i]['counts'].decode()
                 data['segmentation'] = segms[i]
-                json_results.append(data)
-    return json_results
+                segm_json_results.append(data)
+    return bbox_json_results, segm_json_results
 
 
 def results2json(dataset, results, out_file):
+    result_files = dict()
     if isinstance(results[0], list):
         json_results = det2json(dataset, results)
+        result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox')
+        result_files['proposal'] = '{}.{}.json'.format(out_file, 'bbox')
+        mmcv.dump(json_results, result_files['bbox'])
     elif isinstance(results[0], tuple):
         json_results = segm2json(dataset, results)
+        result_files['bbox'] = '{}.{}.json'.format(out_file, 'bbox')
+        result_files['proposal'] = '{}.{}.json'.format(out_file, 'bbox')
+        result_files['segm'] = '{}.{}.json'.format(out_file, 'segm')
+        mmcv.dump(json_results[0], result_files['bbox'])
+        mmcv.dump(json_results[1], result_files['segm'])
     elif isinstance(results[0], np.ndarray):
         json_results = proposal2json(dataset, results)
+        result_files['proposal'] = '{}.{}.json'.format(out_file, 'proposal')
+        mmcv.dump(json_results, result_files['proposal'])
     else:
         raise TypeError('invalid type of results')
-    mmcv.dump(json_results, out_file)
+    return result_files
diff --git a/mmdet/core/evaluation/eval_hooks.py b/mmdet/core/evaluation/eval_hooks.py
index fb12578..806067d 100644
--- a/mmdet/core/evaluation/eval_hooks.py
+++ b/mmdet/core/evaluation/eval_hooks.py
@@ -135,15 +135,15 @@ class CocoDistEvalRecallHook(DistEvalHook):
 class CocoDistEvalmAPHook(DistEvalHook):
 
     def evaluate(self, runner, results):
-        tmp_file = osp.join(runner.work_dir, 'temp_0.json')
-        results2json(self.dataset, results, tmp_file)
+        tmp_file = osp.join(runner.work_dir, 'temp_0')
+        result_files = results2json(self.dataset, results, tmp_file)
 
         res_types = ['bbox',
                      'segm'] if runner.model.module.with_mask else ['bbox']
         cocoGt = self.dataset.coco
-        cocoDt = cocoGt.loadRes(tmp_file)
         imgIds = cocoGt.getImgIds()
         for res_type in res_types:
+            cocoDt = cocoGt.loadRes(result_files[res_type])
             iou_type = res_type
             cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
             cocoEval.params.imgIds = imgIds
@@ -159,4 +159,5 @@ class CocoDistEvalmAPHook(DistEvalHook):
                 '{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
                 '{ap[4]:.3f} {ap[5]:.3f}').format(ap=cocoEval.stats[:6])
         runner.log_buffer.ready = True
-        os.remove(tmp_file)
+        for res_type in res_types:
+            os.remove(result_files[res_type])
diff --git a/tools/test.py b/tools/test.py
index af950aa..df629f3 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -184,16 +184,16 @@ def main():
                 coco_eval(result_file, eval_types, dataset.coco)
             else:
                 if not isinstance(outputs[0], dict):
-                    result_file = args.out + '.json'
-                    results2json(dataset, outputs, result_file)
-                    coco_eval(result_file, eval_types, dataset.coco)
+                    result_files = results2json(dataset, outputs, args.out)
+                    coco_eval(result_files, eval_types, dataset.coco)
                 else:
                     for name in outputs[0]:
                         print('\nEvaluating {}'.format(name))
                         outputs_ = [out[name] for out in outputs]
-                        result_file = args.out + '.{}.json'.format(name)
-                        results2json(dataset, outputs_, result_file)
-                        coco_eval(result_file, eval_types, dataset.coco)
+                        out_file = args.out + '.{}'.format(name)
+                        result_files = results2json(dataset, outputs_,
+                                                    out_file)
+                        coco_eval(result_files, eval_types, dataset.coco)
 
 
 if __name__ == '__main__':
-- 
GitLab