From 8b47a12b8d968e4e7bfaf105023671a6d05728d5 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Mon, 8 Oct 2018 13:46:19 +0800
Subject: [PATCH] minor updates for train/test scripts

---
 tools/test.py  | 7 +++----
 tools/train.py | 8 +++++---
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/tools/test.py b/tools/test.py
index 4e2ecd2..c0bfd25 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -44,17 +44,16 @@ def parse_args():
         '--eval',
         type=str,
         nargs='+',
-        choices=['proposal', 'bbox', 'segm', 'keypoints'],
+        choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
         help='eval types')
     parser.add_argument('--show', action='store_true', help='show results')
     args = parser.parse_args()
     return args
 
 
-args = parse_args()
-
-
 def main():
+    args = parse_args()
+
     cfg = mmcv.Config.fromfile(args.config)
     cfg.model.pretrained = None
     cfg.data.test.test_mode = True
diff --git a/tools/train.py b/tools/train.py
index b72adeb..07a918d 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -2,6 +2,7 @@ from __future__ import division
 
 import argparse
 import logging
+import random
 from collections import OrderedDict
 
 import numpy as np
@@ -55,6 +56,7 @@ def get_logger(log_level):
 
 
 def set_random_seed(seed):
+    random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
@@ -89,7 +91,7 @@ def main():
     if args.work_dir is not None:
         cfg.work_dir = args.work_dir
     cfg.gpus = args.gpus
-    # add mmdet version to checkpoint as meta data
+    # save mmdet version in checkpoint as meta data
     cfg.checkpoint_config.meta = dict(
         mmdet_version=__version__, config=cfg.text)
 
@@ -103,13 +105,13 @@ def main():
     # init distributed environment if necessary
     if args.launcher == 'none':
         dist = False
-        logger.info('Disabled distributed training.')
+        logger.info('Non-distributed training.')
     else:
         dist = True
         init_dist(args.launcher, **cfg.dist_params)
         if torch.distributed.get_rank() != 0:
             logger.setLevel('ERROR')
-        logger.info('Enabled distributed training.')
+        logger.info('Distributed training.')
 
     # prepare data loaders
     train_dataset = obj_from_dict(cfg.data.train, datasets)
-- 
GitLab