import copy

from mmdet.utils import build_from_cfg
from .dataset_wrappers import ConcatDataset, RepeatDataset
from .registry import DATASETS


def _concat_dataset(cfg, default_args=None):
    """Build one dataset per annotation file and concatenate them.

    ``cfg['ann_file']`` is a list/tuple of annotation files. For each entry a
    deep copy of ``cfg`` is made with ``ann_file`` replaced by that single
    entry. ``img_prefix``, ``seg_prefix`` and ``proposal_file`` are split the
    same way when they are lists/tuples; any other value (including ``None``
    or a missing key) is left as-is in every copy.

    Args:
        cfg (dict): Dataset config whose ``ann_file`` is a list or tuple.
        default_args (dict, optional): Forwarded unchanged to
            :func:`build_dataset` for every sub-dataset.

    Returns:
        ConcatDataset: The per-file datasets, in ``ann_file`` order.
    """
    ann_files = cfg['ann_file']
    # Keys that may be given either once (shared) or once per annotation file.
    per_file_keys = ('img_prefix', 'seg_prefix', 'proposal_file')
    per_file_values = {key: cfg.get(key, None) for key in per_file_keys}

    datasets = []
    for i, ann_file in enumerate(ann_files):
        # Deep copy so sub-configs never share mutable state with ``cfg``.
        data_cfg = copy.deepcopy(cfg)
        data_cfg['ann_file'] = ann_file
        for key, value in per_file_values.items():
            if isinstance(value, (list, tuple)):
                # NOTE: a per-file list shorter than ``ann_file`` raises
                # IndexError here; extra trailing entries are ignored.
                data_cfg[key] = value[i]
        datasets.append(build_dataset(data_cfg, default_args))
    return ConcatDataset(datasets)


def build_dataset(cfg, default_args=None):
    """Build a dataset from a config, expanding wrapper and multi-file forms.

    Dispatch order:

    * ``cfg`` is a list/tuple of configs: build each one and wrap them in a
      ``ConcatDataset``.
    * ``cfg['type'] == 'RepeatDataset'``: build ``cfg['dataset']`` and wrap
      it in ``RepeatDataset`` together with ``cfg['times']``.
    * ``cfg['ann_file']`` is a list/tuple: build one dataset per file and
      concatenate them (see :func:`_concat_dataset`).
    * Otherwise: build ``cfg`` through the ``DATASETS`` registry.

    Args:
        cfg (dict | list | tuple): Dataset config, or a sequence of configs.
        default_args (dict, optional): Forwarded to every nested build and
            finally to ``build_from_cfg``.

    Returns:
        The constructed dataset object.
    """
    if isinstance(cfg, (list, tuple)):
        return ConcatDataset([build_dataset(c, default_args) for c in cfg])
    if cfg['type'] == 'RepeatDataset':
        return RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    # ``.get`` rather than ``[]``: configs without an ``ann_file`` key must
    # fall through to the registry instead of raising KeyError here.
    if isinstance(cfg.get('ann_file'), (list, tuple)):
        return _concat_dataset(cfg, default_args)
    return build_from_cfg(cfg, DATASETS, default_args)