From 20762ce9e7b2394ddae4331b4bddff5a1ba8284a Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Wed, 10 Oct 2018 21:44:55 +0800
Subject: [PATCH] bug fix for proposal evaluation

---
 mmdet/core/evaluation/coco_utils.py |  4 ++--
 tools/test.py                       | 17 +++++++++++++----
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/mmdet/core/evaluation/coco_utils.py b/mmdet/core/evaluation/coco_utils.py
index e9fdb41..0ed056b 100644
--- a/mmdet/core/evaluation/coco_utils.py
+++ b/mmdet/core/evaluation/coco_utils.py
@@ -16,8 +16,8 @@ def coco_eval(result_file, result_types, coco, max_dets=(100, 300, 1000)):
         coco = COCO(coco)
     assert isinstance(coco, COCO)
 
-    if res_type == 'proposal_fast':
-        ar = fast_eval_recall(result_file, coco, max_dets)
+    if result_types == ['proposal_fast']:
+        ar = fast_eval_recall(result_file, coco, np.array(max_dets))
         for i, num in enumerate(max_dets):
             print('AR@{}\t= {:.4f}'.format(num, ar[i]))
         return
diff --git a/tools/test.py b/tools/test.py
index b322bb2..8552561 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -55,6 +55,9 @@ def parse_args():
 def main():
     args = parse_args()
 
+    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
+        raise ValueError('The output file must be a pkl file.')
+
     cfg = mmcv.Config.fromfile(args.config)
     cfg.model.pretrained = None
     cfg.data.test.test_mode = True
@@ -82,11 +85,17 @@ def main():
                                 dataset, _data_func, range(args.gpus))
 
     if args.out:
+        print('writing results to {}'.format(args.out))
         mmcv.dump(outputs, args.out)
-        if args.eval:
-            json_file = args.out + '.json'
-            results2json(dataset, outputs, json_file)
-            coco_eval(json_file, args.eval, dataset.coco)
+        eval_types = args.eval
+        if eval_types:
+            print('Starting evaluate {}'.format(' and '.join(eval_types)))
+            if eval_types == ['proposal_fast']:
+                result_file = args.out
+            else:
+                result_file = args.out + '.json'
+                results2json(dataset, outputs, result_file)
+            coco_eval(result_file, eval_types, dataset.coco)
 
 
 if __name__ == '__main__':
-- 
GitLab