# mmdet/datasets/repeat_dataset.py
import operator

import numpy as np


class RepeatDataset(object):
    """Present a dataset as if it were concatenated with itself ``times`` times.

    Item ``i`` of the wrapper is item ``i % len(dataset)`` of the wrapped
    dataset, and the wrapper's length is ``times * len(dataset)``. The
    wrapped dataset's length is read once, at construction time.

    Args:
        dataset: Object supporting ``len()`` and integer indexing.
        times: Non-negative integer number of repetitions.

    Raises:
        TypeError: If ``times`` is not an integer.
        ValueError: If ``times`` is negative.
    """

    def __init__(self, dataset, times):
        # Fail fast: a float or negative ``times`` would otherwise only
        # surface later as an obscure error from ``len()`` or ``np.tile``.
        times = operator.index(times)
        if times < 0:
            raise ValueError(
                'times must be non-negative, got {}'.format(times))
        self.dataset = dataset
        self.times = times
        if hasattr(self.dataset, 'flag'):
            # Repeat the wrapped dataset's ``flag`` array so it lines up with
            # the repeated indices (presumably consumed by the group
            # samplers -- confirm against the loader code).
            self.flag = np.tile(self.dataset.flag, times)
        self._original_length = len(self.dataset)

    def __getitem__(self, idx):
        """Return the item at ``idx``; negative indices count from the end.

        Raises:
            IndexError: If ``idx`` lies outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            idx += total
        # Raising IndexError is what lets ``for x in ds`` / ``list(ds)``
        # terminate: without ``__iter__`` Python iterates by calling
        # ``__getitem__`` with 0, 1, 2, ... until IndexError is raised.
        # It also covers the empty-dataset case (modulo by zero).
        if not 0 <= idx < total:
            raise IndexError('RepeatDataset index out of range')
        return self.dataset[idx % self._original_length]

    def __len__(self):
        """Return ``times`` multiplied by the wrapped dataset's length."""
        return self.times * self._original_length
import datasets @@ -74,6 +75,10 @@ def show_ann(coco, img, ann_info): def get_dataset(data_cfg): + if data_cfg['type'] == 'RepeatDataset': + return RepeatDataset( + get_dataset(data_cfg['dataset']), data_cfg['times']) + if isinstance(data_cfg['ann_file'], (list, tuple)): ann_files = data_cfg['ann_file'] num_dset = len(ann_files)