diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
index dc51381815174c5d5b1366642d0c0fc7aa126755..35ad34cdff84a1c20b68c8b329d6613d2f625c5b 100644
--- a/mmdet/apis/train.py
+++ b/mmdet/apis/train.py
@@ -136,12 +136,11 @@ def build_optimizer(model, optimizer_cfg):
 
 def _dist_train(model, dataset, cfg, validate=False):
     # prepare data loaders
+    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
     data_loaders = [
         build_dataloader(
-            dataset,
-            cfg.data.imgs_per_gpu,
-            cfg.data.workers_per_gpu,
-            dist=True)
+            ds, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu, dist=True)
+        for ds in dataset
     ]
     # put model on gpus
     model = MMDistributedDataParallel(model.cuda())
@@ -189,13 +188,14 @@ def _dist_train(model, dataset, cfg, validate=False):
 
 def _non_dist_train(model, dataset, cfg, validate=False):
     # prepare data loaders
+    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
     data_loaders = [
         build_dataloader(
-            dataset,
+            ds,
             cfg.data.imgs_per_gpu,
             cfg.data.workers_per_gpu,
             cfg.gpus,
-            dist=False)
+            dist=False) for ds in dataset
     ]
     # put model on gpus
     model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
diff --git a/tools/train.py b/tools/train.py
index 7909cde157769993cb1c4d789af67ba4b666042c..085416323331c22d8ea4c84efc4a0697bc9514d7 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -83,19 +83,21 @@ def main():
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
 
-    train_dataset = build_dataset(cfg.data.train)
+    datasets = [build_dataset(cfg.data.train)]
+    if len(cfg.workflow) == 2:
+        datasets.append(build_dataset(cfg.data.val))
     if cfg.checkpoint_config is not None:
         # save mmdet version, config file content and class names in
         # checkpoints as meta data
         cfg.checkpoint_config.meta = dict(
             mmdet_version=__version__,
             config=cfg.text,
-            CLASSES=train_dataset.CLASSES)
+            CLASSES=datasets[0].CLASSES)
     # add an attribute for visualization convenience
-    model.CLASSES = train_dataset.CLASSES
+    model.CLASSES = datasets[0].CLASSES
     train_detector(
         model,
-        train_dataset,
+        datasets,
         cfg,
         distributed=distributed,
         validate=args.validate,