From c829f05471190341ef65a92c3bf000d5c26e2f62 Mon Sep 17 00:00:00 2001
From: Demetris Marnerides <dmarnerides@gmail.com>
Date: Fri, 26 Jul 2019 08:24:00 +0100
Subject: [PATCH] Added Registry use for validation datasets (distributed)
 (#1058)

* Added Registry use for validation datasets (distributed)

* Allowing for default_args for build_dataset

* Using build_dataset instead of build_from_cfg
---
 mmdet/apis/train.py                 |  4 ++--
 mmdet/core/evaluation/eval_hooks.py |  5 ++---
 mmdet/datasets/builder.py           | 13 +++++++------
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
index 60d6c7d..0ebbb9e 100644
--- a/mmdet/apis/train.py
+++ b/mmdet/apis/train.py
@@ -11,7 +11,7 @@ from mmdet import datasets
 from mmdet.core import (DistOptimizerHook, DistEvalmAPHook,
                         CocoDistEvalRecallHook, CocoDistEvalmAPHook,
                         Fp16OptimizerHook)
-from mmdet.datasets import build_dataloader
+from mmdet.datasets import build_dataloader, DATASETS
 from mmdet.models import RPN
 from .env import get_root_logger
 
@@ -174,7 +174,7 @@ def _dist_train(model, dataset, cfg, validate=False):
             runner.register_hook(
                 CocoDistEvalRecallHook(val_dataset_cfg, **eval_cfg))
         else:
-            dataset_type = getattr(datasets, val_dataset_cfg.type)
+            dataset_type = DATASETS.get(val_dataset_cfg.type)
             if issubclass(dataset_type, datasets.CocoDataset):
                 runner.register_hook(
                     CocoDistEvalmAPHook(val_dataset_cfg, **eval_cfg))
diff --git a/mmdet/core/evaluation/eval_hooks.py b/mmdet/core/evaluation/eval_hooks.py
index 99bd2d8..1786d93 100644
--- a/mmdet/core/evaluation/eval_hooks.py
+++ b/mmdet/core/evaluation/eval_hooks.py
@@ -5,7 +5,7 @@ import mmcv
 import numpy as np
 import torch
 import torch.distributed as dist
-from mmcv.runner import Hook, obj_from_dict
+from mmcv.runner import Hook
 from mmcv.parallel import scatter, collate
 from pycocotools.cocoeval import COCOeval
 from torch.utils.data import Dataset
@@ -21,8 +21,7 @@ class DistEvalHook(Hook):
         if isinstance(dataset, Dataset):
             self.dataset = dataset
         elif isinstance(dataset, dict):
-            self.dataset = obj_from_dict(dataset, datasets,
-                                         {'test_mode': True})
+            self.dataset = datasets.build_dataset(dataset, {'test_mode': True})
         else:
             raise TypeError(
                 'dataset must be a Dataset object or a dict, not {}'.format(
diff --git a/mmdet/datasets/builder.py b/mmdet/datasets/builder.py
index 6b1ffba..3d4e2ca 100644
--- a/mmdet/datasets/builder.py
+++ b/mmdet/datasets/builder.py
@@ -5,7 +5,7 @@ from .dataset_wrappers import ConcatDataset, RepeatDataset
 from .registry import DATASETS
 
 
-def _concat_dataset(cfg):
+def _concat_dataset(cfg, default_args=None):
     ann_files = cfg['ann_file']
     img_prefixes = cfg.get('img_prefix', None)
     seg_prefixes = cfg.get('seg_prefixes', None)
@@ -22,17 +22,18 @@ def _concat_dataset(cfg):
             data_cfg['seg_prefix'] = seg_prefixes[i]
         if isinstance(proposal_files, (list, tuple)):
             data_cfg['proposal_file'] = proposal_files[i]
-        datasets.append(build_dataset(data_cfg))
+        datasets.append(build_dataset(data_cfg, default_args))
 
     return ConcatDataset(datasets)
 
 
-def build_dataset(cfg):
+def build_dataset(cfg, default_args=None):
     if cfg['type'] == 'RepeatDataset':
-        dataset = RepeatDataset(build_dataset(cfg['dataset']), cfg['times'])
+        dataset = RepeatDataset(build_dataset(cfg['dataset'], default_args),
+                                cfg['times'])
     elif isinstance(cfg['ann_file'], (list, tuple)):
-        dataset = _concat_dataset(cfg)
+        dataset = _concat_dataset(cfg, default_args)
     else:
-        dataset = build_from_cfg(cfg, DATASETS)
+        dataset = build_from_cfg(cfg, DATASETS, default_args)
 
     return dataset
-- 
GitLab
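
Usage sketch (illustrative, not part of the patch): after this change, build_dataset
accepts an optional default_args dict whose entries are applied as defaults on the
dataset config before the DATASETS registry constructs the dataset class; this is how
DistEvalHook now injects test_mode=True when building validation datasets. The config
values below are placeholders, not taken from the patch.

    from mmdet.datasets import build_dataset

    # Hypothetical validation dataset config (normally this is cfg.data.val).
    val_cfg = dict(
        type='CocoDataset',
        ann_file='data/coco/annotations/instances_val2017.json',
        img_prefix='data/coco/val2017/',
        pipeline=[])

    # default_args fills in keys missing from val_cfg, so the dataset is
    # built with test_mode=True, mirroring what DistEvalHook does here.
    val_dataset = build_dataset(val_cfg, dict(test_mode=True))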