Skip to content
Snippets Groups Projects
Unverified Commit ab5bca65 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #144 from yhcao6/RepeatDataset

repeat dataset
parents 65a2e5ea 2afb5a2f
No related branches found
No related tags found
No related merge requests found
...@@ -3,9 +3,10 @@ from .coco import CocoDataset ...@@ -3,9 +3,10 @@ from .coco import CocoDataset
from .loader import GroupSampler, DistributedGroupSampler, build_dataloader from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
from .utils import to_tensor, random_scale, show_ann, get_dataset from .utils import to_tensor, random_scale, show_ann, get_dataset
from .concat_dataset import ConcatDataset from .concat_dataset import ConcatDataset
from .repeat_dataset import RepeatDataset
__all__ = [ __all__ = [
'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler',
'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale', 'build_dataloader', 'to_tensor', 'random_scale', 'show_ann',
'show_ann', 'get_dataset' 'get_dataset', 'ExtraAugmentation', 'ConcatDataset', 'RepeatDataset',
] ]
import numpy as np
class RepeatDataset(object):
def __init__(self, dataset, times):
self.dataset = dataset
self.times = times
if hasattr(self.dataset, 'flag'):
self.flag = np.tile(self.dataset.flag, times)
self._original_length = len(self.dataset)
def __getitem__(self, idx):
return self.dataset[idx % self._original_length]
def __len__(self):
return self.times * self._original_length
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from .concat_dataset import ConcatDataset from .concat_dataset import ConcatDataset
from .repeat_dataset import RepeatDataset
from .. import datasets from .. import datasets
...@@ -74,6 +75,10 @@ def show_ann(coco, img, ann_info): ...@@ -74,6 +75,10 @@ def show_ann(coco, img, ann_info):
def get_dataset(data_cfg): def get_dataset(data_cfg):
if data_cfg['type'] == 'RepeatDataset':
return RepeatDataset(
get_dataset(data_cfg['dataset']), data_cfg['times'])
if isinstance(data_cfg['ann_file'], (list, tuple)): if isinstance(data_cfg['ann_file'], (list, tuple)):
ann_files = data_cfg['ann_file'] ann_files = data_cfg['ann_file']
num_dset = len(ann_files) num_dset = len(ann_files)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment