From 63288def2839c3cea45b2ea23eb932b2c192e54e Mon Sep 17 00:00:00 2001 From: yhcao6 <yhcao6@gmail.com> Date: Thu, 6 Dec 2018 22:16:51 +0800 Subject: [PATCH] support recursion --- mmdet/datasets/__init__.py | 5 +++-- mmdet/datasets/repeat_dataset.py | 13 ++++++------- mmdet/datasets/utils.py | 7 +------ 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index 3c936c7..9c29482 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -7,6 +7,7 @@ from .repeat_dataset import RepeatDataset __all__ = [ 'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', - 'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale', - 'show_ann', 'get_dataset', 'RepeatDataset' + 'build_dataloader', 'to_tensor', 'random_scale', 'show_ann', + 'get_dataset', 'ExtraAugmentation', 'ConcatDataset', 'RepeatDataset', ] + diff --git a/mmdet/datasets/repeat_dataset.py b/mmdet/datasets/repeat_dataset.py index 317bdf5..5edf33c 100644 --- a/mmdet/datasets/repeat_dataset.py +++ b/mmdet/datasets/repeat_dataset.py @@ -3,16 +3,15 @@ import numpy as np class RepeatDataset(object): - def __init__(self, dataset, repeat_times): + def __init__(self, dataset, times): self.dataset = dataset - self.repeat_times = repeat_times + self.times = times if hasattr(self.dataset, 'flag'): - self.flag = np.tile(self.dataset.flag, repeat_times) - self.length = len(self.dataset) * self.repeat_times + self.flag = np.tile(self.dataset.flag, times) + self._original_length = len(self.dataset) def __getitem__(self, idx): - return self.dataset[idx % len(self.dataset)] + return self.dataset[idx % self._original_length] def __len__(self): - return self.length - + return self.times * self._original_length diff --git a/mmdet/datasets/utils.py b/mmdet/datasets/utils.py index e8f3519..0af0db0 100644 --- a/mmdet/datasets/utils.py +++ b/mmdet/datasets/utils.py @@ -75,10 +75,8 @@ def show_ann(coco, img, ann_info): def get_dataset(data_cfg): - repeat_times = None if data_cfg['type'] == 'RepeatDataset': - repeat_times = data_cfg['repeat_times'] - data_cfg = data_cfg['dataset'] + return RepeatDataset(get_dataset(data_cfg['type']), data_cfg['times']) if isinstance(data_cfg['ann_file'], (list, tuple)): ann_files = data_cfg['ann_file'] @@ -114,7 +112,4 @@ def get_dataset(data_cfg): dset = ConcatDataset(dsets) else: dset = dsets[0] - - if repeat_times is not None: - dset = RepeatDataset(dset, repeat_times) return dset -- GitLab