diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index 60d6c7d63be01ee4ee39ad9bacc926cdb6024957..0ebbb9e110fb1c246916ba7c5efbdef96d9c7025 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -11,7 +11,7 @@ from mmdet import datasets from mmdet.core import (DistOptimizerHook, DistEvalmAPHook, CocoDistEvalRecallHook, CocoDistEvalmAPHook, Fp16OptimizerHook) -from mmdet.datasets import build_dataloader +from mmdet.datasets import build_dataloader, DATASETS from mmdet.models import RPN from .env import get_root_logger @@ -174,7 +174,7 @@ def _dist_train(model, dataset, cfg, validate=False): runner.register_hook( CocoDistEvalRecallHook(val_dataset_cfg, **eval_cfg)) else: - dataset_type = getattr(datasets, val_dataset_cfg.type) + dataset_type = DATASETS.get(val_dataset_cfg.type) if issubclass(dataset_type, datasets.CocoDataset): runner.register_hook( CocoDistEvalmAPHook(val_dataset_cfg, **eval_cfg)) diff --git a/mmdet/core/evaluation/eval_hooks.py b/mmdet/core/evaluation/eval_hooks.py index 99bd2d8d6ccd2b57cc6d0c76ec4f535dddf01377..1786d9362410bc3e535f436bdd7777746dab141c 100644 --- a/mmdet/core/evaluation/eval_hooks.py +++ b/mmdet/core/evaluation/eval_hooks.py @@ -5,7 +5,7 @@ import mmcv import numpy as np import torch import torch.distributed as dist -from mmcv.runner import Hook, obj_from_dict +from mmcv.runner import Hook from mmcv.parallel import scatter, collate from pycocotools.cocoeval import COCOeval from torch.utils.data import Dataset @@ -21,8 +21,7 @@ class DistEvalHook(Hook): if isinstance(dataset, Dataset): self.dataset = dataset elif isinstance(dataset, dict): - self.dataset = obj_from_dict(dataset, datasets, - {'test_mode': True}) + self.dataset = datasets.build_dataset(dataset, {'test_mode': True}) else: raise TypeError( 'dataset must be a Dataset object or a dict, not {}'.format( diff --git a/mmdet/datasets/builder.py b/mmdet/datasets/builder.py index 6b1ffbaf96c990df74273139657c3395faf05749..3d4e2caafe38a5831620f93ac60399869cafa1d0 100644 --- a/mmdet/datasets/builder.py +++ b/mmdet/datasets/builder.py @@ -5,7 +5,7 @@ from .dataset_wrappers import ConcatDataset, RepeatDataset from .registry import DATASETS -def _concat_dataset(cfg): +def _concat_dataset(cfg, default_args=None): ann_files = cfg['ann_file'] img_prefixes = cfg.get('img_prefix', None) seg_prefixes = cfg.get('seg_prefixes', None) @@ -22,17 +22,18 @@ def _concat_dataset(cfg): data_cfg['seg_prefix'] = seg_prefixes[i] if isinstance(proposal_files, (list, tuple)): data_cfg['proposal_file'] = proposal_files[i] - datasets.append(build_dataset(data_cfg)) + datasets.append(build_dataset(data_cfg, default_args)) return ConcatDataset(datasets) -def build_dataset(cfg): +def build_dataset(cfg, default_args=None): if cfg['type'] == 'RepeatDataset': - dataset = RepeatDataset(build_dataset(cfg['dataset']), cfg['times']) + dataset = RepeatDataset(build_dataset(cfg['dataset'], default_args), + cfg['times']) elif isinstance(cfg['ann_file'], (list, tuple)): - dataset = _concat_dataset(cfg) + dataset = _concat_dataset(cfg, default_args) else: - dataset = build_from_cfg(cfg, DATASETS) + dataset = build_from_cfg(cfg, DATASETS, default_args) return dataset