diff --git a/tools/train.py b/tools/train.py
index f596f5693d2b0ec7e09915cb35ff66a36192f3ff..dde6b06b5c3ae6e27da01e21109ccbc0b34d2337 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -4,6 +4,7 @@
 import argparse
 import logging
 from collections import OrderedDict
+import numpy as np
 import torch
 from mmcv import Config
 from mmcv.torchpack import Runner, obj_from_dict
@@ -53,6 +54,12 @@ def get_logger(log_level):
     return logger
 
 
+def set_random_seed(seed):
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description='Train a detector')
     parser.add_argument('config', help='train config file path')
@@ -63,6 +70,7 @@
         help='whether to add a validate phase')
     parser.add_argument(
         '--gpus', type=int, default=1, help='number of gpus to use')
+    parser.add_argument('--seed', type=int, help='random seed')
     parser.add_argument(
         '--launcher',
         choices=['none', 'pytorch', 'slurm', 'mpi'],
@@ -84,6 +92,11 @@ def main():
 
     logger = get_logger(cfg.log_level)
 
+    # set random seed if specified
+    if args.seed is not None:
+        logger.info('Set random seed to {}'.format(args.seed))
+        set_random_seed(args.seed)
+
     # init distributed environment if necessary
     if args.launcher == 'none':
         dist = False