From ebc831226cfc5694b56ffec5e07527d37346fb3a Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Sun, 28 Apr 2019 17:06:25 -0700 Subject: [PATCH] check dataset type by subclass instead of names --- mmdet/apis/train.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index 38bafa4..152aad8 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -6,6 +6,7 @@ import torch from mmcv.runner import Runner, DistSamplerSeedHook from mmcv.parallel import MMDataParallel, MMDistributedDataParallel +from mmdet import datasets from mmdet.core import (DistOptimizerHook, DistEvalmAPHook, CocoDistEvalRecallHook, CocoDistEvalmAPHook) from mmdet.datasets import build_dataloader @@ -80,14 +81,16 @@ def _dist_train(model, dataset, cfg, validate=False): runner.register_hook(DistSamplerSeedHook()) # register eval hooks if validate: + val_dataset_cfg = cfg.data.val if isinstance(model.module, RPN): # TODO: implement recall hooks for other datasets - runner.register_hook(CocoDistEvalRecallHook(cfg.data.val)) + runner.register_hook(CocoDistEvalRecallHook(val_dataset_cfg)) else: - if cfg.data.val.type == 'CocoDataset': - runner.register_hook(CocoDistEvalmAPHook(cfg.data.val)) + dataset_type = getattr(datasets, val_dataset_cfg.type) + if issubclass(dataset_type, datasets.CocoDataset): + runner.register_hook(CocoDistEvalmAPHook(val_dataset_cfg)) else: - runner.register_hook(DistEvalmAPHook(cfg.data.val)) + runner.register_hook(DistEvalmAPHook(val_dataset_cfg)) if cfg.resume_from: runner.resume(cfg.resume_from) -- GitLab