From 7906bd2064d8511063d917d0d066d346d1edcaf2 Mon Sep 17 00:00:00 2001
From: wangg12 <guwang12@gmail.com>
Date: Wed, 28 Nov 2018 11:32:47 +0800
Subject: [PATCH] support training on dataset with multiple ann_files

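With this patch, ann_file in cfg.data.train may be either a single
annotation file or a list/tuple of annotation files. When several files
are given, one dataset is built per file (sharing the remaining dataset
settings) and the results are wrapped in the new ConcatDataset, which also
concatenates the per-dataset aspect-ratio group flags so that GroupSampler
keeps working. A minimal CocoDataset-style config sketch (file names are
placeholders, other dataset settings omitted):

    data = dict(
        train=dict(
            type='CocoDataset',
            ann_file=[
                'annotations/instances_train_part1.json',
                'annotations/instances_train_part2.json',
            ],
            img_prefix='train_images/'))

The helper ConcatDataset.get_idxs maps a global sample index back to
(dataset_idx, sample_idx); e.g. with child dataset lengths [100, 50],
get_idxs(120) returns (1, 20).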
---
 mmdet/datasets/__init__.py       |  6 ++++--
 mmdet/datasets/concat_dataset.py | 30 ++++++++++++++++++++++++++++++
 tools/train.py                   | 22 +++++++++++++++++++++-
 3 files changed, 55 insertions(+), 3 deletions(-)
 create mode 100644 mmdet/datasets/concat_dataset.py

diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py
index 75e0709..a849a2a 100644
--- a/mmdet/datasets/__init__.py
+++ b/mmdet/datasets/__init__.py
@@ -2,8 +2,10 @@ from .custom import CustomDataset
 from .coco import CocoDataset
 from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
 from .utils import to_tensor, random_scale, show_ann
+from .concat_dataset import ConcatDataset
 
 __all__ = [
-    'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler',
-    'build_dataloader', 'to_tensor', 'random_scale', 'show_ann'
+    'CustomDataset', 'CocoDataset', 'ConcatDataset', 'GroupSampler',
+    'DistributedGroupSampler', 'build_dataloader', 'to_tensor',
+    'random_scale', 'show_ann'
 ]
diff --git a/mmdet/datasets/concat_dataset.py b/mmdet/datasets/concat_dataset.py
new file mode 100644
index 0000000..47e073a
--- /dev/null
+++ b/mmdet/datasets/concat_dataset.py
@@ -0,0 +1,30 @@
+import bisect
+import numpy as np
+from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+
+
+class ConcatDataset(_ConcatDataset):
+    """
+    Same as torch.utils.data.dataset.ConcatDataset, but it also
+    concatenates the group flags used for aspect-ratio grouping.
+    """
+    def __init__(self, datasets):
+        """
+        flag: images with an aspect ratio greater than 1 are assigned to
+              group 1, all others to group 0.
+        """
+        super(ConcatDataset, self).__init__(datasets)
+        if hasattr(datasets[0], 'flag'):
+            # Concatenate the group flags of all child datasets so that
+            # GroupSampler can still batch images of similar aspect ratio.
+            flags = [dataset.flag for dataset in datasets]
+            self.flag = np.concatenate(flags)
+
+    def get_idxs(self, idx):
+        # Map a global sample index to (dataset_idx, sample_idx within it).
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        return dataset_idx, sample_idx
diff --git a/tools/train.py b/tools/train.py
index 8e03628..c1e84b1 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -1,10 +1,12 @@
 from __future__ import division
 
 import argparse
+import copy
 from mmcv import Config
 from mmcv.runner import obj_from_dict
 
 from mmdet import datasets, __version__
+from mmdet.datasets import ConcatDataset
 from mmdet.apis import (train_detector, init_dist, get_root_logger,
                         set_random_seed)
 from mmdet.models import build_detector
@@ -36,6 +38,24 @@ def parse_args():
     return args
 
 
+def get_train_dataset(cfg):
+    # Build the training dataset. If ann_file is a list/tuple, build one
+    # dataset per annotation file and wrap them in ConcatDataset.
+    if isinstance(cfg.data.train['ann_file'], (list, tuple)):
+        train_datasets = []
+        for ann_file in cfg.data.train['ann_file']:
+            data_info = copy.deepcopy(cfg.data.train)
+            data_info['ann_file'] = ann_file
+            train_datasets.append(obj_from_dict(data_info, datasets))
+        if len(train_datasets) > 1:
+            train_dataset = ConcatDataset(train_datasets)
+        else:
+            train_dataset = train_datasets[0]
+    else:
+        train_dataset = obj_from_dict(cfg.data.train, datasets)
+    return train_dataset
+
+
 def main():
     args = parse_args()
 
@@ -67,7 +87,7 @@ def main():
 
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
-    train_dataset = obj_from_dict(cfg.data.train, datasets)
+    train_dataset = get_train_dataset(cfg)
     train_detector(
         model,
         train_dataset,
-- 
GitLab