From 3dc0047c29430248047274cf415bbf6c33dbc642 Mon Sep 17 00:00:00 2001
From: Vladimir Iglovikov <ternaus@users.noreply.github.com>
Date: Wed, 24 Jul 2019 20:10:26 -0700
Subject: [PATCH] Added an option to save inference results to the json file
 (#1049)

* Added an option to save inference results to the json file

* flake8 fixes

* yapf fixes
---
 tools/test.py | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/tools/test.py b/tools/test.py
index 0078db5..f3a3f25 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -105,6 +105,10 @@ def parse_args():
     parser.add_argument('config', help='test config file path')
     parser.add_argument('checkpoint', help='checkpoint file')
     parser.add_argument('--out', help='output result file')
+    parser.add_argument(
+        '--json_out',
+        help='output result file name without extension',
+        type=str)
     parser.add_argument(
         '--eval',
         type=str,
@@ -128,13 +132,16 @@ def parse_args():
 def main():
     args = parse_args()
 
-    assert args.out or args.show, \
+    assert args.out or args.show or args.json_out, \
         ('Please specify at least one operation (save or show the results) '
-         'with the argument "--out" or "--show"')
+         'with the argument "--out" or "--show" or "--json_out"')
 
     if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
         raise ValueError('The output file must be a pkl file.')
 
+    if args.json_out is not None and args.json_out.endswith('.json'):
+        args.json_out = args.json_out[:-5]
+
     cfg = mmcv.Config.fromfile(args.config)
     # set cudnn_benchmark
     if cfg.get('cudnn_benchmark', False):
@@ -202,6 +209,16 @@ def main():
                                                     result_file)
                         coco_eval(result_files, eval_types, dataset.coco)
 
+    # Save predictions in the COCO json format
+    if args.json_out and rank == 0:
+        if not isinstance(outputs[0], dict):
+            results2json(dataset, outputs, args.json_out)
+        else:
+            for name in outputs[0]:
+                outputs_ = [out[name] for out in outputs]
+                result_file = args.json_out + '.{}'.format(name)
+                results2json(dataset, outputs_, result_file)
+
 
 if __name__ == '__main__':
     main()
-- 
GitLab