From 8b47a12b8d968e4e7bfaf105023671a6d05728d5 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Mon, 8 Oct 2018 13:46:19 +0800
Subject: [PATCH] minor updates for train/test scripts

---
 tools/test.py  | 7 +++----
 tools/train.py | 8 +++++---
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/tools/test.py b/tools/test.py
index 4e2ecd2..c0bfd25 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -44,17 +44,16 @@ def parse_args():
         '--eval',
         type=str,
         nargs='+',
-        choices=['proposal', 'bbox', 'segm', 'keypoints'],
+        choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
         help='eval types')
     parser.add_argument('--show', action='store_true', help='show results')
     args = parser.parse_args()
     return args
 
 
-args = parse_args()
-
-
 def main():
+    args = parse_args()
+
     cfg = mmcv.Config.fromfile(args.config)
     cfg.model.pretrained = None
     cfg.data.test.test_mode = True
diff --git a/tools/train.py b/tools/train.py
index b72adeb..07a918d 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -2,6 +2,7 @@ from __future__ import division
 
 import argparse
 import logging
+import random
 from collections import OrderedDict
 
 import numpy as np
@@ -55,6 +56,7 @@ def get_logger(log_level):
 
 
 def set_random_seed(seed):
+    random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
@@ -89,7 +91,7 @@ def main():
     if args.work_dir is not None:
         cfg.work_dir = args.work_dir
     cfg.gpus = args.gpus
-    # add mmdet version to checkpoint as meta data
+    # save mmdet version in checkpoint as meta data
     cfg.checkpoint_config.meta = dict(
         mmdet_version=__version__, config=cfg.text)
 
@@ -103,13 +105,13 @@
     # init distributed environment if necessary
     if args.launcher == 'none':
         dist = False
-        logger.info('Disabled distributed training.')
+        logger.info('Non-distributed training.')
     else:
         dist = True
         init_dist(args.launcher, **cfg.dist_params)
         if torch.distributed.get_rank() != 0:
             logger.setLevel('ERROR')
-        logger.info('Enabled distributed training.')
+        logger.info('Distributed training.')
 
     # prepare data loaders
     train_dataset = obj_from_dict(cfg.data.train, datasets)
-- 
GitLab