From bc94b21257a787ddb5c29337d8a3965cd742cf33 Mon Sep 17 00:00:00 2001
From: Demetris Marnerides <dmarnerides@gmail.com>
Date: Wed, 7 Aug 2019 06:10:19 +0100
Subject: [PATCH] Allowing validation dataset for computing validation loss
 (#1093)

* Allowing validation dataset for computing validation loss

* added validation dataset

* fixing datasets errors

* Fixing linting errors
---
 mmdet/apis/train.py | 12 ++++++------
 tools/train.py      | 10 ++++++----
 2 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
index dc51381..35ad34c 100644
--- a/mmdet/apis/train.py
+++ b/mmdet/apis/train.py
@@ -136,12 +136,11 @@ def build_optimizer(model, optimizer_cfg):
 
 def _dist_train(model, dataset, cfg, validate=False):
     # prepare data loaders
+    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
     data_loaders = [
         build_dataloader(
-            dataset,
-            cfg.data.imgs_per_gpu,
-            cfg.data.workers_per_gpu,
-            dist=True)
+            ds, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu, dist=True)
+        for ds in dataset
     ]
     # put model on gpus
     model = MMDistributedDataParallel(model.cuda())
@@ -189,13 +188,14 @@ def _dist_train(model, dataset, cfg, validate=False):
 
 def _non_dist_train(model, dataset, cfg, validate=False):
     # prepare data loaders
+    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
     data_loaders = [
         build_dataloader(
-            dataset,
+            ds,
             cfg.data.imgs_per_gpu,
             cfg.data.workers_per_gpu,
             cfg.gpus,
-            dist=False)
+            dist=False) for ds in dataset
     ]
     # put model on gpus
     model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
diff --git a/tools/train.py b/tools/train.py
index 7909cde..0854163 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -83,19 +83,21 @@ def main():
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
 
-    train_dataset = build_dataset(cfg.data.train)
+    datasets = [build_dataset(cfg.data.train)]
+    if len(cfg.workflow) == 2:
+        datasets.append(build_dataset(cfg.data.val))
     if cfg.checkpoint_config is not None:
         # save mmdet version, config file content and class names in
         # checkpoints as meta data
         cfg.checkpoint_config.meta = dict(
             mmdet_version=__version__,
             config=cfg.text,
-            CLASSES=train_dataset.CLASSES)
+            CLASSES=datasets[0].CLASSES)
     # add an attribute for visualization convenience
-    model.CLASSES = train_dataset.CLASSES
+    model.CLASSES = datasets[0].CLASSES
     train_detector(
         model,
-        train_dataset,
+        datasets,
         cfg,
         distributed=distributed,
         validate=args.validate,
-- 
GitLab