diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py
index a849a2a54c7cc3f9599bbe9ee6a7ed0861208343..b478acffebb30d475042259a9b97c8307627b5ec 100644
--- a/mmdet/datasets/__init__.py
+++ b/mmdet/datasets/__init__.py
@@ -1,10 +1,10 @@
 from .custom import CustomDataset
 from .coco import CocoDataset
 from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
-from .utils import to_tensor, random_scale, show_ann
+from .utils import to_tensor, random_scale, show_ann, get_dataset
 from .concat_dataset import ConcatDataset
 
 __all__ = [
     'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'ConcatDataset',
-    'build_dataloader', 'to_tensor', 'random_scale', 'show_ann'
+    'build_dataloader', 'to_tensor', 'random_scale', 'show_ann', 'get_dataset'
 ]
diff --git a/mmdet/datasets/utils.py b/mmdet/datasets/utils.py
index 5a248ef6890ea348ea7ad98154cc163ae1e035c5..d13ae892d485946f122cc33516cfbf78367f7582 100644
--- a/mmdet/datasets/utils.py
+++ b/mmdet/datasets/utils.py
@@ -1,11 +1,13 @@
 from collections import Sequence
-
+import copy
 import mmcv
+from mmcv.runner import obj_from_dict
 import torch
 
 import matplotlib.pyplot as plt
 import numpy as np
-
+from .concat_dataset import ConcatDataset
+from .. import datasets
 
 def to_tensor(data):
     """Convert objects of various python types to :obj:`torch.Tensor`.
@@ -67,3 +69,21 @@ def show_ann(coco, img, ann_info):
     plt.axis('off')
     coco.showAnns(ann_info)
     plt.show()
+
+
+def get_dataset(data_cfg):
+    """Build a dataset from config; if 'ann_file' is a list/tuple, build one
+    dataset per annotation file and wrap them in a ConcatDataset."""
+    if isinstance(data_cfg['ann_file'], (list, tuple)):
+        dsets = []
+        for ann_file in data_cfg['ann_file']:
+            data_info = copy.deepcopy(data_cfg)
+            data_info['ann_file'] = ann_file
+            dsets.append(obj_from_dict(data_info, datasets))
+        if len(dsets) > 1:
+            dset = ConcatDataset(dsets)
+        else:
+            dset = dsets[0]
+    else:
+        dset = obj_from_dict(data_cfg, datasets)
+    return dset
diff --git a/tools/train.py b/tools/train.py
index c1e84b10efc1b585a39dbfdf9fdac11e863453e0..006fc1197f5ea4e1df866c05d8323070b08d51df 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -1,12 +1,10 @@
 from __future__ import division
 
 import argparse
-import copy
 from mmcv import Config
 from mmcv.runner import obj_from_dict
 
 from mmdet import datasets, __version__
-from mmdet.datasets import ConcatDataset
 from mmdet.apis import (train_detector, init_dist, get_root_logger,
                         set_random_seed)
 from mmdet.models import build_detector
@@ -38,24 +36,6 @@ def parse_args():
     return args
 
 
-def get_train_dataset(cfg):
-    if isinstance(cfg.data.train['ann_file'], list) or isinstance(cfg.data.train['ann_file'], tuple):
-        ann_files = cfg.data.train['ann_file']
-        train_datasets = []
-        for ann_file in ann_files:
-            data_info = copy.deepcopy(cfg.data.train)
-            data_info['ann_file'] = ann_file
-            train_dset = obj_from_dict(data_info, datasets)            
-            train_datasets.append(train_dset)
-        if len(train_datasets) > 1:
-            train_dataset = ConcatDataset(train_datasets)            
-        else:
-            train_dataset = train_datasets[0]
-    else:
-        train_dataset = obj_from_dict(cfg.data.train, datasets)
-    return train_dataset
-
-
 def main():
     args = parse_args()
 
@@ -87,7 +67,7 @@ def main():
 
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
-    train_dataset = get_train_dataset(cfg)
+    train_dataset = datasets.get_dataset(cfg.data.train)
     train_detector(
         model,
         train_dataset,