From 63288def2839c3cea45b2ea23eb932b2c192e54e Mon Sep 17 00:00:00 2001
From: yhcao6 <yhcao6@gmail.com>
Date: Thu, 6 Dec 2018 22:16:51 +0800
Subject: [PATCH] support recursion

---
 mmdet/datasets/__init__.py       |  5 +++--
 mmdet/datasets/repeat_dataset.py | 13 ++++++-------
 mmdet/datasets/utils.py          |  7 +------
 3 files changed, 10 insertions(+), 15 deletions(-)

diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py
index 3c936c7..9c29482 100644
--- a/mmdet/datasets/__init__.py
+++ b/mmdet/datasets/__init__.py
@@ -7,6 +7,7 @@ from .repeat_dataset import RepeatDataset
 
 __all__ = [
     'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler',
-    'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale',
-    'show_ann', 'get_dataset', 'RepeatDataset'
+    'build_dataloader', 'to_tensor', 'random_scale', 'show_ann',
+    'get_dataset', 'ExtraAugmentation', 'ConcatDataset', 'RepeatDataset',
 ]
+
diff --git a/mmdet/datasets/repeat_dataset.py b/mmdet/datasets/repeat_dataset.py
index 317bdf5..5edf33c 100644
--- a/mmdet/datasets/repeat_dataset.py
+++ b/mmdet/datasets/repeat_dataset.py
@@ -3,16 +3,15 @@ import numpy as np
 
 class RepeatDataset(object):
 
-    def __init__(self, dataset, repeat_times):
+    def __init__(self, dataset, times):
         self.dataset = dataset
-        self.repeat_times = repeat_times
+        self.times = times
         if hasattr(self.dataset, 'flag'):
-            self.flag = np.tile(self.dataset.flag, repeat_times)
-        self.length = len(self.dataset) * self.repeat_times
+            self.flag = np.tile(self.dataset.flag, times)
+        self._original_length = len(self.dataset)
 
     def __getitem__(self, idx):
-        return self.dataset[idx % len(self.dataset)]
+        return self.dataset[idx % self._original_length]
 
     def __len__(self):
-        return self.length
-
+        return self.times * self._original_length
diff --git a/mmdet/datasets/utils.py b/mmdet/datasets/utils.py
index e8f3519..0af0db0 100644
--- a/mmdet/datasets/utils.py
+++ b/mmdet/datasets/utils.py
@@ -75,10 +75,8 @@ def show_ann(coco, img, ann_info):
 
 
 def get_dataset(data_cfg):
-    repeat_times = None
     if data_cfg['type'] == 'RepeatDataset':
-        repeat_times = data_cfg['repeat_times']
-        data_cfg = data_cfg['dataset']
+        return RepeatDataset(get_dataset(data_cfg['type']), data_cfg['times'])
 
     if isinstance(data_cfg['ann_file'], (list, tuple)):
         ann_files = data_cfg['ann_file']
@@ -114,7 +112,4 @@ def get_dataset(data_cfg):
         dset = ConcatDataset(dsets)
     else:
         dset = dsets[0]
-
-    if repeat_times is not None:
-        dset = RepeatDataset(dset, repeat_times)
     return dset
-- 
GitLab