diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index 38bafa4475646cb640fbd62338b13b144e93b378..152aad838b7c0ad851a353eaa5c5af1145f7b745 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -6,6 +6,7 @@ import torch from mmcv.runner import Runner, DistSamplerSeedHook from mmcv.parallel import MMDataParallel, MMDistributedDataParallel +from mmdet import datasets from mmdet.core import (DistOptimizerHook, DistEvalmAPHook, CocoDistEvalRecallHook, CocoDistEvalmAPHook) from mmdet.datasets import build_dataloader @@ -80,14 +81,16 @@ def _dist_train(model, dataset, cfg, validate=False): runner.register_hook(DistSamplerSeedHook()) # register eval hooks if validate: + val_dataset_cfg = cfg.data.val if isinstance(model.module, RPN): # TODO: implement recall hooks for other datasets - runner.register_hook(CocoDistEvalRecallHook(cfg.data.val)) + runner.register_hook(CocoDistEvalRecallHook(val_dataset_cfg)) else: - if cfg.data.val.type == 'CocoDataset': - runner.register_hook(CocoDistEvalmAPHook(cfg.data.val)) + dataset_type = getattr(datasets, val_dataset_cfg.type) + if issubclass(dataset_type, datasets.CocoDataset): + runner.register_hook(CocoDistEvalmAPHook(val_dataset_cfg)) else: - runner.register_hook(DistEvalmAPHook(cfg.data.val)) + runner.register_hook(DistEvalmAPHook(val_dataset_cfg)) if cfg.resume_from: runner.resume(cfg.resume_from)