From 60d5bc546799ef274670ba5478484fe9af892a7d Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Wed, 12 Dec 2018 00:26:52 +0800 Subject: [PATCH] add eval hooks for VOC dataset --- mmdet/apis/train.py | 12 ++++++--- mmdet/core/evaluation/__init__.py | 4 +-- mmdet/core/evaluation/eval_hooks.py | 39 +++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 6 deletions(-) diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index 2a58972..38bafa4 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -6,8 +6,8 @@ import torch from mmcv.runner import Runner, DistSamplerSeedHook from mmcv.parallel import MMDataParallel, MMDistributedDataParallel -from mmdet.core import (DistOptimizerHook, CocoDistEvalRecallHook, - CocoDistEvalmAPHook) +from mmdet.core import (DistOptimizerHook, DistEvalmAPHook, + CocoDistEvalRecallHook, CocoDistEvalmAPHook) from mmdet.datasets import build_dataloader from mmdet.models import RPN from .env import get_root_logger @@ -81,9 +81,13 @@ def _dist_train(model, dataset, cfg, validate=False): # register eval hooks if validate: if isinstance(model.module, RPN): + # TODO: implement recall hooks for other datasets runner.register_hook(CocoDistEvalRecallHook(cfg.data.val)) - elif cfg.data.val.type == 'CocoDataset': - runner.register_hook(CocoDistEvalmAPHook(cfg.data.val)) + else: + if cfg.data.val.type == 'CocoDataset': + runner.register_hook(CocoDistEvalmAPHook(cfg.data.val)) + else: + runner.register_hook(DistEvalmAPHook(cfg.data.val)) if cfg.resume_from: runner.resume(cfg.resume_from) diff --git a/mmdet/core/evaluation/__init__.py b/mmdet/core/evaluation/__init__.py index 026234f..4585c23 100644 --- a/mmdet/core/evaluation/__init__.py +++ b/mmdet/core/evaluation/__init__.py @@ -2,7 +2,7 @@ from .class_names import (voc_classes, imagenet_det_classes, imagenet_vid_classes, coco_classes, dataset_aliases, get_classes) from .coco_utils import coco_eval, fast_eval_recall, results2json -from .eval_hooks import (DistEvalHook, CocoDistEvalRecallHook, +from .eval_hooks import (DistEvalHook, DistEvalmAPHook, CocoDistEvalRecallHook, CocoDistEvalmAPHook) from .mean_ap import average_precision, eval_map, print_map_summary from .recall import (eval_recalls, print_recall_summary, plot_num_recall, @@ -11,7 +11,7 @@ from .recall import (eval_recalls, print_recall_summary, plot_num_recall, __all__ = [ 'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes', 'coco_classes', 'dataset_aliases', 'get_classes', 'coco_eval', - 'fast_eval_recall', 'results2json', 'DistEvalHook', + 'fast_eval_recall', 'results2json', 'DistEvalHook', 'DistEvalmAPHook', 'CocoDistEvalRecallHook', 'CocoDistEvalmAPHook', 'average_precision', 'eval_map', 'print_map_summary', 'eval_recalls', 'print_recall_summary', 'plot_num_recall', 'plot_iou_recall' diff --git a/mmdet/core/evaluation/eval_hooks.py b/mmdet/core/evaluation/eval_hooks.py index 1402f7f..ba6adbb 100644 --- a/mmdet/core/evaluation/eval_hooks.py +++ b/mmdet/core/evaluation/eval_hooks.py @@ -12,6 +12,7 @@ from pycocotools.cocoeval import COCOeval from torch.utils.data import Dataset from .coco_utils import results2json, fast_eval_recall +from .mean_ap import eval_map from mmdet import datasets @@ -102,6 +103,44 @@ class DistEvalHook(Hook): raise NotImplementedError +class DistEvalmAPHook(DistEvalHook): + + def evaluate(self, runner, results): + gt_bboxes = [] + gt_labels = [] + gt_ignore = [] if self.dataset.with_crowd else None + for i in range(len(self.dataset)): + ann = self.dataset.get_ann_info(i) + bboxes = ann['bboxes'] + labels = ann['labels'] + if gt_ignore is not None: + ignore = np.concatenate([ + np.zeros(bboxes.shape[0], dtype=np.bool), + np.ones(ann['bboxes_ignore'].shape[0], dtype=np.bool) + ]) + gt_ignore.append(ignore) + bboxes = np.vstack([bboxes, ann['bboxes_ignore']]) + labels = np.concatenate([labels, ann['labels_ignore']]) + gt_bboxes.append(bboxes) + gt_labels.append(labels) + # If the dataset is VOC2007, then use 11 points mAP evaluation. + if hasattr(self.dataset, 'year') and self.dataset.year == 2007: + ds_name = 'voc07' + else: + ds_name = self.dataset.CLASSES + mean_ap, eval_results = eval_map( + results, + gt_bboxes, + gt_labels, + gt_ignore=gt_ignore, + scale_ranges=None, + iou_thr=0.5, + dataset=ds_name, + print_summary=True) + runner.log_buffer.output['mAP'] = mean_ap + runner.log_buffer.ready = True + + class CocoDistEvalRecallHook(DistEvalHook): def __init__(self, -- GitLab