From d6b69bdadab42479fd5478d36821b384aba9e10d Mon Sep 17 00:00:00 2001 From: yhcao6 <yhcao6@gmail.com> Date: Thu, 6 Dec 2018 21:44:50 +0800 Subject: [PATCH] add RepeatDataset --- mmdet/datasets/__init__.py | 3 ++- mmdet/datasets/repeat_dataset.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 mmdet/datasets/repeat_dataset.py diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index dda1a1e..3c936c7 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -3,9 +3,10 @@ from .coco import CocoDataset from .loader import GroupSampler, DistributedGroupSampler, build_dataloader from .utils import to_tensor, random_scale, show_ann, get_dataset from .concat_dataset import ConcatDataset +from .repeat_dataset import RepeatDataset __all__ = [ 'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale', - 'show_ann', 'get_dataset' + 'show_ann', 'get_dataset', 'RepeatDataset' ] diff --git a/mmdet/datasets/repeat_dataset.py b/mmdet/datasets/repeat_dataset.py new file mode 100644 index 0000000..317bdf5 --- /dev/null +++ b/mmdet/datasets/repeat_dataset.py @@ -0,0 +1,18 @@ +import numpy as np + + +class RepeatDataset(object): + + def __init__(self, dataset, repeat_times): + self.dataset = dataset + self.repeat_times = repeat_times + if hasattr(self.dataset, 'flag'): + self.flag = np.tile(self.dataset.flag, repeat_times) + self.length = len(self.dataset) * self.repeat_times + + def __getitem__(self, idx): + return self.dataset[idx % len(self.dataset)] + + def __len__(self): + return self.length + -- GitLab