diff --git a/.gitignore b/.gitignore
index 01c47d6e277dba0d7b880dff88f9695f9a8eec50..f189e1d5b6851047dc4230032c134f8134a8d73a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -107,3 +107,5 @@ venv.bak/
 mmdet/ops/nms/*.cpp
 mmdet/version.py
 data
+.vscode
+.idea
diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py
index 75e07097756bd014bbef17294b6803aa83621fd1..dda1a1eb4a514b93254623693753ecf8f3839c12 100644
--- a/mmdet/datasets/__init__.py
+++ b/mmdet/datasets/__init__.py
@@ -1,9 +1,11 @@
 from .custom import CustomDataset
 from .coco import CocoDataset
 from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
-from .utils import to_tensor, random_scale, show_ann
+from .utils import to_tensor, random_scale, show_ann, get_dataset
+from .concat_dataset import ConcatDataset
 
 __all__ = [
     'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler',
-    'build_dataloader', 'to_tensor', 'random_scale', 'show_ann'
+    'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale',
+    'show_ann', 'get_dataset'
 ]
diff --git a/mmdet/datasets/concat_dataset.py b/mmdet/datasets/concat_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e42b60982578550774ab930832f7ee1014ef672f
--- /dev/null
+++ b/mmdet/datasets/concat_dataset.py
@@ -0,0 +1,20 @@
+import numpy as np
+from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+
+
+class ConcatDataset(_ConcatDataset):
+    """
+    Same as torch.utils.data.dataset.ConcatDataset, but
+    concat the group flag for image aspect ratio.
+    """
+    def __init__(self, datasets):
+        """
+        flag: Images with aspect ratio greater than 1 will be set as group 1,
+              otherwise group 0.
+        """
+        super(ConcatDataset, self).__init__(datasets)
+        if hasattr(datasets[0], 'flag'):
+            flags = []
+            for dataset in datasets:
+                flags.append(dataset.flag)
+            self.flag = np.concatenate(flags)
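
A minimal sketch (not part of the diff) of what the subclass adds: the toy dataset below is a hypothetical stand-in for CustomDataset/CocoDataset, which carry a per-image `flag` array, and shows how the concatenated flag lines up with the merged index space so samplers that group by aspect ratio keep working.

```python
import numpy as np
from mmdet.datasets import ConcatDataset


class _ToyDataset:
    """Hypothetical stand-in for a dataset that carries a per-image flag."""

    def __init__(self, flag):
        self.flag = np.asarray(flag, dtype=np.uint8)

    def __len__(self):
        return len(self.flag)

    def __getitem__(self, idx):
        return idx


merged = ConcatDataset([_ToyDataset([0, 1, 1]), _ToyDataset([1, 0])])
assert len(merged) == 5
# The group flags of both datasets are concatenated in order, so samplers
# that read `dataset.flag` still see one flag per index of the combined set.
assert merged.flag.tolist() == [0, 1, 1, 1, 0]
```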
diff --git a/mmdet/datasets/utils.py b/mmdet/datasets/utils.py
index 5a248ef6890ea348ea7ad98154cc163ae1e035c5..6f6a0b514d141889c8a5701214ca963efe6b0a61 100644
--- a/mmdet/datasets/utils.py
+++ b/mmdet/datasets/utils.py
@@ -1,10 +1,14 @@
+import copy
 from collections import Sequence
 
 import mmcv
+from mmcv.runner import obj_from_dict
 import torch
 
 import matplotlib.pyplot as plt
 import numpy as np
+from .concat_dataset import ConcatDataset
+from .. import datasets
 
 
 def to_tensor(data):
@@ -67,3 +71,41 @@ def show_ann(coco, img, ann_info):
     plt.axis('off')
     coco.showAnns(ann_info)
     plt.show()
+
+
+def get_dataset(data_cfg):
+    if isinstance(data_cfg['ann_file'], (list, tuple)):
+        ann_files = data_cfg['ann_file']
+        num_dset = len(ann_files)
+    else:
+        ann_files = [data_cfg['ann_file']]
+        num_dset = 1
+
+    if 'proposal_file' in data_cfg:
+        if isinstance(data_cfg['proposal_file'], (list, tuple)):
+            proposal_files = data_cfg['proposal_file']
+        else:
+            proposal_files = [data_cfg['proposal_file']]
+    else:
+        proposal_files = [None] * num_dset
+    assert len(proposal_files) == num_dset
+
+    if isinstance(data_cfg['img_prefix'], (list, tuple)):
+        img_prefixes = data_cfg['img_prefix']
+    else:
+        img_prefixes = [data_cfg['img_prefix']] * num_dset
+    assert len(img_prefixes) == num_dset
+
+    dsets = []
+    for i in range(num_dset):
+        data_info = copy.deepcopy(data_cfg)
+        data_info['ann_file'] = ann_files[i]
+        data_info['proposal_file'] = proposal_files[i]
+        data_info['img_prefix'] = img_prefixes[i]
+        dset = obj_from_dict(data_info, datasets)
+        dsets.append(dset)
+    if len(dsets) > 1:
+        dset = ConcatDataset(dsets)
+    else:
+        dset = dsets[0]
+    return dset
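
For context (not part of the diff), a hedged sketch of how `get_dataset` is meant to be driven from a config: `ann_file` may be a single path or a list, `proposal_file` and `img_prefix` may be shared or given per annotation file, and every other key is forwarded to the dataset class named by `type` via `obj_from_dict`. The paths below are placeholders and the remaining CocoDataset arguments are elided.

```python
# Illustrative cfg.data.train slice with placeholder paths; the remaining
# CocoDataset constructor arguments (img_scale, img_norm_cfg, ...) are
# omitted here but required in a real config.
train = dict(
    type='CocoDataset',
    ann_file=[
        'data/coco/annotations/instances_train_part1.json',
        'data/coco/annotations/instances_train_part2.json',
    ],
    img_prefix='data/coco/train2017/',
)

# get_dataset(train) would build one CocoDataset per annotation file,
# reusing the single img_prefix (and a None proposal_file) for both, and
# wrap them in ConcatDataset; with a plain string ann_file it returns the
# single dataset unwrapped.
```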
diff --git a/tools/train.py b/tools/train.py
index 8e03628db5ea28d027ccdc3939c72bace482be93..bd47e66bed121629368a86cd9d809a6cfe2b9363 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -2,9 +2,9 @@ from __future__ import division
 
 import argparse
 from mmcv import Config
-from mmcv.runner import obj_from_dict
 
-from mmdet import datasets, __version__
+from mmdet import __version__
+from mmdet.datasets import get_dataset
 from mmdet.apis import (train_detector, init_dist, get_root_logger,
                         set_random_seed)
 from mmdet.models import build_detector
@@ -67,7 +67,7 @@ def main():
 
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
-    train_dataset = obj_from_dict(cfg.data.train, datasets)
+    train_dataset = get_dataset(cfg.data.train)
     train_detector(
         model,
         train_dataset,