diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index dda1a1eb4a514b93254623693753ecf8f3839c12..3c936c7251c3484c5deaa488b3e120dbfe0bc147 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -3,9 +3,10 @@ from .coco import CocoDataset from .loader import GroupSampler, DistributedGroupSampler, build_dataloader from .utils import to_tensor, random_scale, show_ann, get_dataset from .concat_dataset import ConcatDataset +from .repeat_dataset import RepeatDataset __all__ = [ 'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale', - 'show_ann', 'get_dataset' + 'show_ann', 'get_dataset', 'RepeatDataset' ] diff --git a/mmdet/datasets/repeat_dataset.py b/mmdet/datasets/repeat_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..317bdf544b0fd3b7850cfe977ac5092df8594e78 --- /dev/null +++ b/mmdet/datasets/repeat_dataset.py @@ -0,0 +1,18 @@ +import numpy as np + + +class RepeatDataset(object): + + def __init__(self, dataset, repeat_times): + self.dataset = dataset + self.repeat_times = repeat_times + if hasattr(self.dataset, 'flag'): + self.flag = np.tile(self.dataset.flag, repeat_times) + self.length = len(self.dataset) * self.repeat_times + + def __getitem__(self, idx): + return self.dataset[idx % len(self.dataset)] + + def __len__(self): + return self.length +