From bc94b21257a787ddb5c29337d8a3965cd742cf33 Mon Sep 17 00:00:00 2001 From: Demetris Marnerides <dmarnerides@gmail.com> Date: Wed, 7 Aug 2019 06:10:19 +0100 Subject: [PATCH] Allowing validation dataset for computing validation loss (#1093) * Allowing validation dataset for computing validation loss * added validation dataset * fixing datasets errors * Fixing linting errors --- mmdet/apis/train.py | 12 ++++++------ tools/train.py | 10 ++++++---- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index dc51381..35ad34c 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -136,12 +136,11 @@ def build_optimizer(model, optimizer_cfg): def _dist_train(model, dataset, cfg, validate=False): # prepare data loaders + dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] data_loaders = [ build_dataloader( - dataset, - cfg.data.imgs_per_gpu, - cfg.data.workers_per_gpu, - dist=True) + ds, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu, dist=True) + for ds in dataset ] # put model on gpus model = MMDistributedDataParallel(model.cuda()) @@ -189,13 +188,14 @@ def _dist_train(model, dataset, cfg, validate=False): def _non_dist_train(model, dataset, cfg, validate=False): # prepare data loaders + dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] data_loaders = [ build_dataloader( - dataset, + ds, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu, cfg.gpus, - dist=False) + dist=False) for ds in dataset ] # put model on gpus model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda() diff --git a/tools/train.py b/tools/train.py index 7909cde..0854163 100644 --- a/tools/train.py +++ b/tools/train.py @@ -83,19 +83,21 @@ def main(): model = build_detector( cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) - train_dataset = build_dataset(cfg.data.train) + datasets = [build_dataset(cfg.data.train)] + if len(cfg.workflow) == 2: + datasets.append(build_dataset(cfg.data.val)) if cfg.checkpoint_config is not None: # save mmdet version, config file content and class names in # checkpoints as meta data cfg.checkpoint_config.meta = dict( mmdet_version=__version__, config=cfg.text, - CLASSES=train_dataset.CLASSES) + CLASSES=datasets[0].CLASSES) # add an attribute for visualization convenience - model.CLASSES = train_dataset.CLASSES + model.CLASSES = datasets[0].CLASSES train_detector( model, - train_dataset, + datasets, cfg, distributed=distributed, validate=args.validate, -- GitLab