diff --git a/tools/test.py b/tools/test.py index 0078db5a197131bc76392eb515b74a1fc4a065aa..f3a3f25304059c702db2c3413400623d9291c885 100644 --- a/tools/test.py +++ b/tools/test.py @@ -105,6 +105,10 @@ def parse_args(): parser.add_argument('config', help='test config file path') parser.add_argument('checkpoint', help='checkpoint file') parser.add_argument('--out', help='output result file') + parser.add_argument( + '--json_out', + help='output result file name without extension', + type=str) parser.add_argument( '--eval', type=str, @@ -128,13 +132,16 @@ def parse_args(): def main(): args = parse_args() - assert args.out or args.show, \ + assert args.out or args.show or args.json_out, \ ('Please specify at least one operation (save or show the results) ' - 'with the argument "--out" or "--show"') + 'with the argument "--out" or "--show" or "--json_out"') if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): raise ValueError('The output file must be a pkl file.') + if args.json_out is not None and args.json_out.endswith('.json'): + args.json_out = args.json_out[:-5] + cfg = mmcv.Config.fromfile(args.config) # set cudnn_benchmark if cfg.get('cudnn_benchmark', False): @@ -202,6 +209,16 @@ def main(): result_file) coco_eval(result_files, eval_types, dataset.coco) + # Save predictions in the COCO json format + if args.json_out and rank == 0: + if not isinstance(outputs[0], dict): + results2json(dataset, outputs, args.json_out) + else: + for name in outputs[0]: + outputs_ = [out[name] for out in outputs] + result_file = args.json_out + '.{}'.format(name) + results2json(dataset, outputs_, result_file) + if __name__ == '__main__': main()