diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index 3c936c7251c3484c5deaa488b3e120dbfe0bc147..9c2948286fd5ec369c93ab18b42ebba7520ae470 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -7,6 +7,7 @@ from .repeat_dataset import RepeatDataset __all__ = [ 'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', - 'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale', - 'show_ann', 'get_dataset', 'RepeatDataset' + 'build_dataloader', 'to_tensor', 'random_scale', 'show_ann', + 'get_dataset', 'ExtraAugmentation', 'ConcatDataset', 'RepeatDataset', ] + diff --git a/mmdet/datasets/repeat_dataset.py b/mmdet/datasets/repeat_dataset.py index 317bdf544b0fd3b7850cfe977ac5092df8594e78..5edf33c7417a917287469e36c9c839d882adf8dd 100644 --- a/mmdet/datasets/repeat_dataset.py +++ b/mmdet/datasets/repeat_dataset.py @@ -3,16 +3,15 @@ import numpy as np class RepeatDataset(object): - def __init__(self, dataset, repeat_times): + def __init__(self, dataset, times): self.dataset = dataset - self.repeat_times = repeat_times + self.times = times if hasattr(self.dataset, 'flag'): - self.flag = np.tile(self.dataset.flag, repeat_times) - self.length = len(self.dataset) * self.repeat_times + self.flag = np.tile(self.dataset.flag, times) + self._original_length = len(self.dataset) def __getitem__(self, idx): - return self.dataset[idx % len(self.dataset)] + return self.dataset[idx % self._original_length] def __len__(self): - return self.length - + return self.times * self._original_length diff --git a/mmdet/datasets/utils.py b/mmdet/datasets/utils.py index e8f351926ad7d2e6dea897ee94d96fcc46f39443..0af0db01582dd461f57b7798cc6afbbba661ca2e 100644 --- a/mmdet/datasets/utils.py +++ b/mmdet/datasets/utils.py @@ -75,10 +75,8 @@ def show_ann(coco, img, ann_info): def get_dataset(data_cfg): - repeat_times = None if data_cfg['type'] == 'RepeatDataset': - repeat_times = data_cfg['repeat_times'] - data_cfg = data_cfg['dataset'] + return RepeatDataset(get_dataset(data_cfg['type']), data_cfg['times']) if isinstance(data_cfg['ann_file'], (list, tuple)): ann_files = data_cfg['ann_file'] @@ -114,7 +112,4 @@ def get_dataset(data_cfg): dset = ConcatDataset(dsets) else: dset = dsets[0] - - if repeat_times is not None: - dset = RepeatDataset(dset, repeat_times) return dset