Skip to content
Snippets Groups Projects
Unverified Commit 51904cbc authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #559 from hellock/dev

Check dataset type by subclass instead of names
parents 86187a20 ebc83122
No related branches found
No related tags found
No related merge requests found
...@@ -6,6 +6,7 @@ import torch ...@@ -6,6 +6,7 @@ import torch
from mmcv.runner import Runner, DistSamplerSeedHook from mmcv.runner import Runner, DistSamplerSeedHook
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmdet import datasets
from mmdet.core import (DistOptimizerHook, DistEvalmAPHook, from mmdet.core import (DistOptimizerHook, DistEvalmAPHook,
CocoDistEvalRecallHook, CocoDistEvalmAPHook) CocoDistEvalRecallHook, CocoDistEvalmAPHook)
from mmdet.datasets import build_dataloader from mmdet.datasets import build_dataloader
...@@ -80,14 +81,16 @@ def _dist_train(model, dataset, cfg, validate=False): ...@@ -80,14 +81,16 @@ def _dist_train(model, dataset, cfg, validate=False):
runner.register_hook(DistSamplerSeedHook()) runner.register_hook(DistSamplerSeedHook())
# register eval hooks # register eval hooks
if validate: if validate:
val_dataset_cfg = cfg.data.val
if isinstance(model.module, RPN): if isinstance(model.module, RPN):
# TODO: implement recall hooks for other datasets # TODO: implement recall hooks for other datasets
runner.register_hook(CocoDistEvalRecallHook(cfg.data.val)) runner.register_hook(CocoDistEvalRecallHook(val_dataset_cfg))
else: else:
if cfg.data.val.type == 'CocoDataset': dataset_type = getattr(datasets, val_dataset_cfg.type)
runner.register_hook(CocoDistEvalmAPHook(cfg.data.val)) if issubclass(dataset_type, datasets.CocoDataset):
runner.register_hook(CocoDistEvalmAPHook(val_dataset_cfg))
else: else:
runner.register_hook(DistEvalmAPHook(cfg.data.val)) runner.register_hook(DistEvalmAPHook(val_dataset_cfg))
if cfg.resume_from: if cfg.resume_from:
runner.resume(cfg.resume_from) runner.resume(cfg.resume_from)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment