diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index c146b04feab9e354fb7953a253a2711e7c778ca2..a79537f70dac3209c73ad158ffaff1469310d9f0 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -17,13 +17,6 @@ from mmdet.datasets import DATASETS, build_dataloader from mmdet.models import RPN -def set_random_seed(seed): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - def get_root_logger(log_file=None, log_level=logging.INFO): logger = logging.getLogger('mmdet') # if the logger has been initialized, just return it @@ -45,6 +38,25 @@ def get_root_logger(log_file=None, log_level=logging.INFO): return logger +def set_random_seed(seed, deterministic=False): + """Set random seed. + + Args: + seed (int): Seed to be used. + deterministic (bool): Whether to set the deterministic option for + CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` + to True and `torch.backends.cudnn.benchmark` to False. + Default: False. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + def parse_losses(losses): log_vars = OrderedDict() for loss_name, loss_value in losses.items(): @@ -70,6 +82,21 @@ def parse_losses(losses): def batch_processor(model, data, train_mode): + """Process a data batch. + + This method is required as an argument of Runner, which defines how to + process a data batch and obtain proper outputs. The first 3 arguments of + batch_processor are fixed. + + Args: + model (nn.Module): A PyTorch model. + data (dict): The data batch in a dict. + train_mode (bool): Training mode or not. It may be useless for some + models. + + Returns: + dict: A dict containing losses and log vars. + """ losses = model(**data) loss, log_vars = parse_losses(losses) diff --git a/mmdet/datasets/loader/build_loader.py b/mmdet/datasets/loader/build_loader.py index 32275b41ce19db088fc6340e65a1ff9bcc14894a..e9431d7bab3701dae9614dd9d4b956a7d66c3903 100644 --- a/mmdet/datasets/loader/build_loader.py +++ b/mmdet/datasets/loader/build_loader.py @@ -21,8 +21,30 @@ def build_dataloader(dataset, dist=True, shuffle=True, **kwargs): + """Build PyTorch DataLoader. + + In distributed training, each GPU/process has a dataloader. + In non-distributed training, there is only one dataloader for all GPUs. + + Args: + dataset (Dataset): A PyTorch dataset. + imgs_per_gpu (int): Number of images on each GPU, i.e., batch size of + each GPU. + workers_per_gpu (int): How many subprocesses to use for data loading + for each GPU. + num_gpus (int): Number of GPUs. Only used in non-distributed training. + dist (bool): Distributed training/test or not. Default: True. + shuffle (bool): Whether to shuffle the data at every epoch. + Default: True. + kwargs: any keyword argument to be used to initialize DataLoader + + Returns: + DataLoader: A PyTorch dataloader. + """ if dist: rank, world_size = get_dist_info() + # DistributedGroupSampler will definitely shuffle the data to satisfy + # that images on each GPU are in the same group if shuffle: sampler = DistributedGroupSampler(dataset, imgs_per_gpu, world_size, rank) diff --git a/tools/train.py b/tools/train.py index 5958d2409b810c344c867a59ff173a58d8b881d3..2931a28c92858bdd3cbb4fe866e7a291bd5f7f7d 100644 --- a/tools/train.py +++ b/tools/train.py @@ -32,6 +32,10 @@ def parse_args(): help='number of gpus to use ' '(only applicable to non-distributed training)') parser.add_argument('--seed', type=int, default=None, help='random seed') + parser.add_argument( + '--deterministic', + action='store_true', + help='whether to set deterministic options for CUDNN backend.') parser.add_argument( '--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'], @@ -88,8 +92,9 @@ def main(): # set random seeds if args.seed is not None: - logger.info('Set random seed to {}'.format(args.seed)) - set_random_seed(args.seed) + logger.info('Set random seed to {}, deterministic: {}'.format( + args.seed, args.deterministic)) + set_random_seed(args.seed, deterministic=args.deterministic) model = build_detector( cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)