Skip to content
Snippets Groups Projects
Commit 60d5bc54 authored by Kai Chen's avatar Kai Chen
Browse files

add eval hooks for VOC dataset

parent 9d38a278
No related branches found
No related tags found
No related merge requests found
......@@ -6,8 +6,8 @@ import torch
from mmcv.runner import Runner, DistSamplerSeedHook
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmdet.core import (DistOptimizerHook, CocoDistEvalRecallHook,
CocoDistEvalmAPHook)
from mmdet.core import (DistOptimizerHook, DistEvalmAPHook,
CocoDistEvalRecallHook, CocoDistEvalmAPHook)
from mmdet.datasets import build_dataloader
from mmdet.models import RPN
from .env import get_root_logger
......@@ -81,9 +81,13 @@ def _dist_train(model, dataset, cfg, validate=False):
# register eval hooks
if validate:
if isinstance(model.module, RPN):
# TODO: implement recall hooks for other datasets
runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
elif cfg.data.val.type == 'CocoDataset':
runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
else:
if cfg.data.val.type == 'CocoDataset':
runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
else:
runner.register_hook(DistEvalmAPHook(cfg.data.val))
if cfg.resume_from:
runner.resume(cfg.resume_from)
......
......@@ -2,7 +2,7 @@ from .class_names import (voc_classes, imagenet_det_classes,
imagenet_vid_classes, coco_classes, dataset_aliases,
get_classes)
from .coco_utils import coco_eval, fast_eval_recall, results2json
from .eval_hooks import (DistEvalHook, CocoDistEvalRecallHook,
from .eval_hooks import (DistEvalHook, DistEvalmAPHook, CocoDistEvalRecallHook,
CocoDistEvalmAPHook)
from .mean_ap import average_precision, eval_map, print_map_summary
from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
......@@ -11,7 +11,7 @@ from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
__all__ = [
'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes',
'coco_classes', 'dataset_aliases', 'get_classes', 'coco_eval',
'fast_eval_recall', 'results2json', 'DistEvalHook',
'fast_eval_recall', 'results2json', 'DistEvalHook', 'DistEvalmAPHook',
'CocoDistEvalRecallHook', 'CocoDistEvalmAPHook', 'average_precision',
'eval_map', 'print_map_summary', 'eval_recalls', 'print_recall_summary',
'plot_num_recall', 'plot_iou_recall'
......
......@@ -12,6 +12,7 @@ from pycocotools.cocoeval import COCOeval
from torch.utils.data import Dataset
from .coco_utils import results2json, fast_eval_recall
from .mean_ap import eval_map
from mmdet import datasets
......@@ -102,6 +103,44 @@ class DistEvalHook(Hook):
raise NotImplementedError
class DistEvalmAPHook(DistEvalHook):
def evaluate(self, runner, results):
gt_bboxes = []
gt_labels = []
gt_ignore = [] if self.dataset.with_crowd else None
for i in range(len(self.dataset)):
ann = self.dataset.get_ann_info(i)
bboxes = ann['bboxes']
labels = ann['labels']
if gt_ignore is not None:
ignore = np.concatenate([
np.zeros(bboxes.shape[0], dtype=np.bool),
np.ones(ann['bboxes_ignore'].shape[0], dtype=np.bool)
])
gt_ignore.append(ignore)
bboxes = np.vstack([bboxes, ann['bboxes_ignore']])
labels = np.concatenate([labels, ann['labels_ignore']])
gt_bboxes.append(bboxes)
gt_labels.append(labels)
# If the dataset is VOC2007, then use 11 points mAP evaluation.
if hasattr(self.dataset, 'year') and self.dataset.year == 2007:
ds_name = 'voc07'
else:
ds_name = self.dataset.CLASSES
mean_ap, eval_results = eval_map(
results,
gt_bboxes,
gt_labels,
gt_ignore=gt_ignore,
scale_ranges=None,
iou_thr=0.5,
dataset=ds_name,
print_summary=True)
runner.log_buffer.output['mAP'] = mean_ap
runner.log_buffer.ready = True
class CocoDistEvalRecallHook(DistEvalHook):
def __init__(self,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment