diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
index 38bafa4475646cb640fbd62338b13b144e93b378..152aad838b7c0ad851a353eaa5c5af1145f7b745 100644
--- a/mmdet/apis/train.py
+++ b/mmdet/apis/train.py
@@ -6,6 +6,7 @@ import torch
 from mmcv.runner import Runner, DistSamplerSeedHook
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 
+from mmdet import datasets
 from mmdet.core import (DistOptimizerHook, DistEvalmAPHook,
                         CocoDistEvalRecallHook, CocoDistEvalmAPHook)
 from mmdet.datasets import build_dataloader
@@ -80,14 +81,16 @@ def _dist_train(model, dataset, cfg, validate=False):
     runner.register_hook(DistSamplerSeedHook())
     # register eval hooks
     if validate:
+        val_dataset_cfg = cfg.data.val
         if isinstance(model.module, RPN):
             # TODO: implement recall hooks for other datasets
-            runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
+            runner.register_hook(CocoDistEvalRecallHook(val_dataset_cfg))
         else:
-            if cfg.data.val.type == 'CocoDataset':
-                runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
+            dataset_type = getattr(datasets, val_dataset_cfg.type)
+            if issubclass(dataset_type, datasets.CocoDataset):
+                runner.register_hook(CocoDistEvalmAPHook(val_dataset_cfg))
             else:
-                runner.register_hook(DistEvalmAPHook(cfg.data.val))
+                runner.register_hook(DistEvalmAPHook(val_dataset_cfg))
 
     if cfg.resume_from:
         runner.resume(cfg.resume_from)