diff --git a/tools/test.py b/tools/test.py index 8552561b623ec37337423de3a9a066d38192a197..dc8dc5e85ce415b5149227b0035cf1d88d70c677 100644 --- a/tools/test.py +++ b/tools/test.py @@ -39,7 +39,13 @@ def parse_args(): parser = argparse.ArgumentParser(description='MMDet test detector') parser.add_argument('config', help='test config file path') parser.add_argument('checkpoint', help='checkpoint file') - parser.add_argument('--gpus', default=1, type=int) + parser.add_argument( + '--gpus', default=1, type=int, help='GPU number used for testing') + parser.add_argument( + '--proc_per_gpu', + default=1, + type=int, + help='Number of processes per GPU') parser.add_argument('--out', help='output result file') parser.add_argument( '--eval', @@ -81,8 +87,14 @@ def main(): model_args = cfg.model.copy() model_args.update(train_cfg=None, test_cfg=cfg.test_cfg) model_type = getattr(detectors, model_args.pop('type')) - outputs = parallel_test(model_type, model_args, args.checkpoint, - dataset, _data_func, range(args.gpus)) + outputs = parallel_test( + model_type, + model_args, + args.checkpoint, + dataset, + _data_func, + range(args.gpus), + workers_per_gpu=args.proc_per_gpu) if args.out: print('writing results to {}'.format(args.out))