From 1f3e27344b59788fa7debb174efdfc041b3612e2 Mon Sep 17 00:00:00 2001
From: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
Date: Fri, 25 Oct 2019 19:47:37 +0800
Subject: [PATCH] Support viewing AP for each class (#1549)

* also support viewing AP

* change string format

* eval class_wise in coco_eval

* reformat

* class_wise API from detectron

* reformat

* change code source

* reformat, use terminaltable
---
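Notes: COCOeval stores its accumulated results in
cocoEval.eval['precision'], a 5-D array indexed as
(iou, recall, cls, area range, max dets); with the pycocotools
defaults that is 10 IoU thresholds (0.50:0.05:0.95), 101 recall
thresholds, K categories, 4 area ranges (all/small/medium/large) and
3 max-dets settings (1, 10, 100). A minimal sketch of how one
class-wise AP falls out of that array, assuming cocoEval.evaluate()
and cocoEval.accumulate() have already run (variable names here are
illustrative only, not part of this patch):

    import numpy as np

    # precision tensor: (iou, recall, cls, area range, max dets)
    precisions = cocoEval.eval['precision']
    k = 0                           # index of the category of interest
    p = precisions[:, :, k, 0, -1]  # all area ranges, max dets = 100
    p = p[p > -1]                   # -1 marks slots with no valid data
    ap = float(p.mean()) if p.size else float('nan')
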
 mmdet/core/evaluation/coco_utils.py | 39 ++++++++++++++++++++++++++++-
 tools/coco_eval.py                  |  4 ++-
 2 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/mmdet/core/evaluation/coco_utils.py b/mmdet/core/evaluation/coco_utils.py
index f6f5ac0..ef44940 100644
--- a/mmdet/core/evaluation/coco_utils.py
+++ b/mmdet/core/evaluation/coco_utils.py
@@ -1,12 +1,19 @@
+import itertools
+
 import mmcv
 import numpy as np
 from pycocotools.coco import COCO
 from pycocotools.cocoeval import COCOeval
+from terminaltables import AsciiTable
 
 from .recall import eval_recalls
 
 
-def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000)):
+def coco_eval(result_files,
+              result_types,
+              coco,
+              max_dets=(100, 300, 1000),
+              classwise=False):
     for res_type in result_types:
         assert res_type in [
             'proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'
@@ -43,6 +50,36 @@ def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000)):
         cocoEval.accumulate()
         cocoEval.summarize()
 
+        if classwise:
+            # Compute per-category AP
+            # from https://github.com/facebookresearch/detectron2/blob/03064eb5bafe4a3e5750cc7a16672daf5afe8435/detectron2/evaluation/coco_evaluation.py#L259-L283 # noqa
+            precisions = cocoEval.eval['precision']
+            catIds = coco.getCatIds()
+            # precision has dims (iou, recall, cls, area range, max dets)
+            assert len(catIds) == precisions.shape[2]
+
+            results_per_category = []
+            for idx, catId in enumerate(catIds):
+                # area range index 0: all area ranges
+                # max dets index -1: typically 100 per image
+                nm = coco.loadCats(catId)[0]
+                precision = precisions[:, :, idx, 0, -1]
+                precision = precision[precision > -1]
+                ap = np.mean(precision) if precision.size else float('nan')
+                results_per_category.append(
+                    ('{}'.format(nm['name']),
+                     '{:0.3f}'.format(float(ap * 100))))
+
+            N_COLS = min(6, len(results_per_category) * 2)
+            results_flatten = list(itertools.chain(*results_per_category))
+            headers = ['category', 'AP'] * (N_COLS // 2)
+            results_2d = itertools.zip_longest(
+                *[results_flatten[i::N_COLS] for i in range(N_COLS)])
+            table_data = [headers]
+            table_data += [result for result in results_2d]
+            table = AsciiTable(table_data)
+            print(table.table)
+
 
 def fast_eval_recall(results,
                      coco,
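
The table layout above flattens the per-category (name, AP) pairs and
re-slices them into at most six columns. A standalone sketch with toy
data (category names and numbers are made up; fillvalue='' is used
here so the padded cells in the last row stay printable):

    import itertools
    from terminaltables import AsciiTable

    results_per_category = [('person', '55.1'), ('car', '43.2'),
                            ('dog', '60.7'), ('cat', '58.9'),
                            ('bus', '49.0')]
    N_COLS = min(6, len(results_per_category) * 2)
    flat = list(itertools.chain(*results_per_category))
    headers = ['category', 'AP'] * (N_COLS // 2)
    rows = itertools.zip_longest(
        *[flat[i::N_COLS] for i in range(N_COLS)], fillvalue='')
    print(AsciiTable([headers] + [list(r) for r in rows]).table)
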
diff --git a/tools/coco_eval.py b/tools/coco_eval.py
index 65e114c..bc3c96b 100644
--- a/tools/coco_eval.py
+++ b/tools/coco_eval.py
@@ -20,8 +20,10 @@ def main():
         nargs='+',
         default=[100, 300, 1000],
         help='proposal numbers, only used for recall evaluation')
+    parser.add_argument(
+        '--classwise', action='store_true', help='evaluate class-wise AP')
     args = parser.parse_args()
-    coco_eval(args.result, args.types, args.ann, args.max_dets)
+    coco_eval(args.result, args.types, args.ann, args.max_dets, args.classwise)
 
 
 if __name__ == '__main__':
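
With this change, per-class AP can be requested from the command line.
A hypothetical invocation (the positional result file and the
--ann/--types flags are inferred from the args.* attributes above and
are not shown in this hunk):

    python tools/coco_eval.py results.bbox.json \
        --ann instances_val2017.json --types bbox --classwise
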
-- 
GitLab