From 60d5bc546799ef274670ba5478484fe9af892a7d Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Wed, 12 Dec 2018 00:26:52 +0800
Subject: [PATCH] add eval hooks for VOC dataset

---
 mmdet/apis/train.py                 | 12 ++++++---
 mmdet/core/evaluation/__init__.py   |  4 +--
 mmdet/core/evaluation/eval_hooks.py | 39 +++++++++++++++++++++++++++++
 3 files changed, 49 insertions(+), 6 deletions(-)

diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
index 2a58972..38bafa4 100644
--- a/mmdet/apis/train.py
+++ b/mmdet/apis/train.py
@@ -6,8 +6,8 @@ import torch
 from mmcv.runner import Runner, DistSamplerSeedHook
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 
-from mmdet.core import (DistOptimizerHook, CocoDistEvalRecallHook,
-                        CocoDistEvalmAPHook)
+from mmdet.core import (DistOptimizerHook, DistEvalmAPHook,
+                        CocoDistEvalRecallHook, CocoDistEvalmAPHook)
 from mmdet.datasets import build_dataloader
 from mmdet.models import RPN
 from .env import get_root_logger
@@ -81,9 +81,13 @@ def _dist_train(model, dataset, cfg, validate=False):
     # register eval hooks
     if validate:
         if isinstance(model.module, RPN):
+            # TODO: implement recall hooks for other datasets
             runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
-        elif cfg.data.val.type == 'CocoDataset':
-            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
+        else:
+            if cfg.data.val.type == 'CocoDataset':
+                runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
+            else:
+                runner.register_hook(DistEvalmAPHook(cfg.data.val))
 
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
diff --git a/mmdet/core/evaluation/__init__.py b/mmdet/core/evaluation/__init__.py
index 026234f..4585c23 100644
--- a/mmdet/core/evaluation/__init__.py
+++ b/mmdet/core/evaluation/__init__.py
@@ -2,7 +2,7 @@ from .class_names import (voc_classes, imagenet_det_classes,
                           imagenet_vid_classes, coco_classes, dataset_aliases,
                           get_classes)
 from .coco_utils import coco_eval, fast_eval_recall, results2json
-from .eval_hooks import (DistEvalHook, CocoDistEvalRecallHook,
+from .eval_hooks import (DistEvalHook, DistEvalmAPHook, CocoDistEvalRecallHook,
                          CocoDistEvalmAPHook)
 from .mean_ap import average_precision, eval_map, print_map_summary
 from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
@@ -11,7 +11,7 @@ from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
 __all__ = [
     'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes',
     'coco_classes', 'dataset_aliases', 'get_classes', 'coco_eval',
-    'fast_eval_recall', 'results2json', 'DistEvalHook',
+    'fast_eval_recall', 'results2json', 'DistEvalHook', 'DistEvalmAPHook',
     'CocoDistEvalRecallHook', 'CocoDistEvalmAPHook', 'average_precision',
     'eval_map', 'print_map_summary', 'eval_recalls', 'print_recall_summary',
     'plot_num_recall', 'plot_iou_recall'
diff --git a/mmdet/core/evaluation/eval_hooks.py b/mmdet/core/evaluation/eval_hooks.py
index 1402f7f..ba6adbb 100644
--- a/mmdet/core/evaluation/eval_hooks.py
+++ b/mmdet/core/evaluation/eval_hooks.py
@@ -12,6 +12,7 @@ from pycocotools.cocoeval import COCOeval
 from torch.utils.data import Dataset
 
 from .coco_utils import results2json, fast_eval_recall
+from .mean_ap import eval_map
 from mmdet import datasets
 
 
@@ -102,6 +103,44 @@ class DistEvalHook(Hook):
         raise NotImplementedError
 
 
+class DistEvalmAPHook(DistEvalHook):
+
+    def evaluate(self, runner, results):
+        gt_bboxes = []
+        gt_labels = []
+        gt_ignore = [] if self.dataset.with_crowd else None
+        for i in range(len(self.dataset)):
+            ann = self.dataset.get_ann_info(i)
+            bboxes = ann['bboxes']
+            labels = ann['labels']
+            if gt_ignore is not None:
+                ignore = np.concatenate([
+                    np.zeros(bboxes.shape[0], dtype=np.bool),
+                    np.ones(ann['bboxes_ignore'].shape[0], dtype=np.bool)
+                ])
+                gt_ignore.append(ignore)
+                bboxes = np.vstack([bboxes, ann['bboxes_ignore']])
+                labels = np.concatenate([labels, ann['labels_ignore']])
+            gt_bboxes.append(bboxes)
+            gt_labels.append(labels)
+        # If the dataset is VOC2007, then use 11 points mAP evaluation.
+        if hasattr(self.dataset, 'year') and self.dataset.year == 2007:
+            ds_name = 'voc07'
+        else:
+            ds_name = self.dataset.CLASSES
+        mean_ap, eval_results = eval_map(
+            results,
+            gt_bboxes,
+            gt_labels,
+            gt_ignore=gt_ignore,
+            scale_ranges=None,
+            iou_thr=0.5,
+            dataset=ds_name,
+            print_summary=True)
+        runner.log_buffer.output['mAP'] = mean_ap
+        runner.log_buffer.ready = True
+
+
 class CocoDistEvalRecallHook(DistEvalHook):
 
     def __init__(self,
-- 
GitLab