From d6b69bdadab42479fd5478d36821b384aba9e10d Mon Sep 17 00:00:00 2001
From: yhcao6 <yhcao6@gmail.com>
Date: Thu, 6 Dec 2018 21:44:50 +0800
Subject: [PATCH] add RepeatDataset

---
 mmdet/datasets/__init__.py       |  3 ++-
 mmdet/datasets/repeat_dataset.py | 18 ++++++++++++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)
 create mode 100644 mmdet/datasets/repeat_dataset.py

diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py
index dda1a1e..3c936c7 100644
--- a/mmdet/datasets/__init__.py
+++ b/mmdet/datasets/__init__.py
@@ -3,9 +3,10 @@ from .coco import CocoDataset
 from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
 from .utils import to_tensor, random_scale, show_ann, get_dataset
 from .concat_dataset import ConcatDataset
+from .repeat_dataset import RepeatDataset
 
 __all__ = [
     'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler',
     'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale',
-    'show_ann', 'get_dataset'
+    'show_ann', 'get_dataset', 'RepeatDataset'
 ]
diff --git a/mmdet/datasets/repeat_dataset.py b/mmdet/datasets/repeat_dataset.py
new file mode 100644
index 0000000..317bdf5
--- /dev/null
+++ b/mmdet/datasets/repeat_dataset.py
@@ -0,0 +1,18 @@
+import numpy as np
+
+
+class RepeatDataset(object):
+
+    def __init__(self, dataset, repeat_times):
+        self.dataset = dataset
+        self.repeat_times = repeat_times
+        if hasattr(self.dataset, 'flag'):
+            self.flag = np.tile(self.dataset.flag, repeat_times)
+        self.length = len(self.dataset) * self.repeat_times
+
+    def __getitem__(self, idx):
+        return self.dataset[idx % len(self.dataset)]
+
+    def __len__(self):
+        return self.length
+
-- 
GitLab