From 7c2b8148c8b2e8b4fa4d0692055153e5f4449072 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Wed, 10 Oct 2018 21:50:57 +0800
Subject: [PATCH] add an argument to specify process per gpu

---
 tools/test.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/tools/test.py b/tools/test.py
index 8552561..dc8dc5e 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -39,7 +39,13 @@ def parse_args():
     parser = argparse.ArgumentParser(description='MMDet test detector')
     parser.add_argument('config', help='test config file path')
     parser.add_argument('checkpoint', help='checkpoint file')
-    parser.add_argument('--gpus', default=1, type=int)
+    parser.add_argument(
+        '--gpus', default=1, type=int, help='GPU number used for testing')
+    parser.add_argument(
+        '--proc_per_gpu',
+        default=1,
+        type=int,
+        help='Number of processes per GPU')
     parser.add_argument('--out', help='output result file')
     parser.add_argument(
         '--eval',
@@ -81,8 +87,14 @@ def main():
         model_args = cfg.model.copy()
         model_args.update(train_cfg=None, test_cfg=cfg.test_cfg)
         model_type = getattr(detectors, model_args.pop('type'))
-        outputs = parallel_test(model_type, model_args, args.checkpoint,
-                                dataset, _data_func, range(args.gpus))
+        outputs = parallel_test(
+            model_type,
+            model_args,
+            args.checkpoint,
+            dataset,
+            _data_func,
+            range(args.gpus),
+            workers_per_gpu=args.proc_per_gpu)
 
     if args.out:
         print('writing results to {}'.format(args.out))
-- 
GitLab