Skip to content
Snippets Groups Projects
Commit d6b69bda authored by yhcao6's avatar yhcao6
Browse files

add RepeatDataset

parent a6ee0532
No related branches found
No related tags found
No related merge requests found
...@@ -3,9 +3,10 @@ from .coco import CocoDataset ...@@ -3,9 +3,10 @@ from .coco import CocoDataset
from .loader import GroupSampler, DistributedGroupSampler, build_dataloader from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
from .utils import to_tensor, random_scale, show_ann, get_dataset from .utils import to_tensor, random_scale, show_ann, get_dataset
from .concat_dataset import ConcatDataset from .concat_dataset import ConcatDataset
from .repeat_dataset import RepeatDataset
__all__ = [ __all__ = [
'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler',
'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale', 'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale',
'show_ann', 'get_dataset' 'show_ann', 'get_dataset', 'RepeatDataset'
] ]
import numpy as np
class RepeatDataset(object):
def __init__(self, dataset, repeat_times):
self.dataset = dataset
self.repeat_times = repeat_times
if hasattr(self.dataset, 'flag'):
self.flag = np.tile(self.dataset.flag, repeat_times)
self.length = len(self.dataset) * self.repeat_times
def __getitem__(self, idx):
return self.dataset[idx % len(self.dataset)]
def __len__(self):
return self.length
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment