diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index dc51381815174c5d5b1366642d0c0fc7aa126755..35ad34cdff84a1c20b68c8b329d6613d2f625c5b 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -136,12 +136,11 @@ def build_optimizer(model, optimizer_cfg): def _dist_train(model, dataset, cfg, validate=False): # prepare data loaders + dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] data_loaders = [ build_dataloader( - dataset, - cfg.data.imgs_per_gpu, - cfg.data.workers_per_gpu, - dist=True) + ds, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu, dist=True) + for ds in dataset ] # put model on gpus model = MMDistributedDataParallel(model.cuda()) @@ -189,13 +188,14 @@ def _dist_train(model, dataset, cfg, validate=False): def _non_dist_train(model, dataset, cfg, validate=False): # prepare data loaders + dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] data_loaders = [ build_dataloader( - dataset, + ds, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu, cfg.gpus, - dist=False) + dist=False) for ds in dataset ] # put model on gpus model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda() diff --git a/tools/train.py b/tools/train.py index 7909cde157769993cb1c4d789af67ba4b666042c..085416323331c22d8ea4c84efc4a0697bc9514d7 100644 --- a/tools/train.py +++ b/tools/train.py @@ -83,19 +83,21 @@ def main(): model = build_detector( cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) - train_dataset = build_dataset(cfg.data.train) + datasets = [build_dataset(cfg.data.train)] + if len(cfg.workflow) == 2: + datasets.append(build_dataset(cfg.data.val)) if cfg.checkpoint_config is not None: # save mmdet version, config file content and class names in # checkpoints as meta data cfg.checkpoint_config.meta = dict( mmdet_version=__version__, config=cfg.text, - CLASSES=train_dataset.CLASSES) + CLASSES=datasets[0].CLASSES) # add an attribute for visualization convenience - model.CLASSES = train_dataset.CLASSES + model.CLASSES = datasets[0].CLASSES train_detector( model, - train_dataset, + datasets, cfg, distributed=distributed, validate=args.validate,