Skip to content
Snippets Groups Projects
Commit 8b47a12b authored by Kai Chen's avatar Kai Chen
Browse files

minor updates for train/test scripts

parent f8dab59d
No related branches found
No related tags found
No related merge requests found
......@@ -44,17 +44,16 @@ def parse_args():
'--eval',
type=str,
nargs='+',
choices=['proposal', 'bbox', 'segm', 'keypoints'],
choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
help='eval types')
parser.add_argument('--show', action='store_true', help='show results')
args = parser.parse_args()
return args
args = parse_args()
def main():
args = parse_args()
cfg = mmcv.Config.fromfile(args.config)
cfg.model.pretrained = None
cfg.data.test.test_mode = True
......
......@@ -2,6 +2,7 @@ from __future__ import division
import argparse
import logging
import random
from collections import OrderedDict
import numpy as np
......@@ -55,6 +56,7 @@ def get_logger(log_level):
def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
......@@ -89,7 +91,7 @@ def main():
if args.work_dir is not None:
cfg.work_dir = args.work_dir
cfg.gpus = args.gpus
# add mmdet version to checkpoint as meta data
# save mmdet version in checkpoint as meta data
cfg.checkpoint_config.meta = dict(
mmdet_version=__version__, config=cfg.text)
......@@ -103,13 +105,13 @@ def main():
# init distributed environment if necessary
if args.launcher == 'none':
dist = False
logger.info('Disabled distributed training.')
logger.info('Non-distributed training.')
else:
dist = True
init_dist(args.launcher, **cfg.dist_params)
if torch.distributed.get_rank() != 0:
logger.setLevel('ERROR')
logger.info('Enabled distributed training.')
logger.info('Distributed training.')
# prepare data loaders
train_dataset = obj_from_dict(cfg.data.train, datasets)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment