diff --git a/mmdet/core/eval/class_names.py b/mmdet/core/eval/class_names.py index b68e9135dca366e93217e0c06959bea990ffda5e..04f806315b7c6ef47419efa61e38d2f7ec3ebd2a 100644 --- a/mmdet/core/eval/class_names.py +++ b/mmdet/core/eval/class_names.py @@ -95,7 +95,7 @@ def get_classes(dataset): if mmcv.is_str(dataset): if dataset in alias2name: - labels = eval(alias2name[dataset] + '_labels()') + labels = eval(alias2name[dataset] + '_classes()') else: raise ValueError('Unrecognized dataset: {}'.format(dataset)) else: diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py index 3b2040312ee08338e4606c2f154a399c048619c7..93a05c8594eb70e34c9291117f32df42b408bd40 100644 --- a/mmdet/models/detectors/base.py +++ b/mmdet/models/detectors/base.py @@ -1,9 +1,13 @@ import logging from abc import ABCMeta, abstractmethod +import mmcv +import numpy as np import torch import torch.nn as nn +from mmdet.core import tensor2imgs, get_classes + class BaseDetector(nn.Module): """Base class for detectors""" @@ -66,3 +70,38 @@ class BaseDetector(nn.Module): return self.forward_train(img, img_meta, **kwargs) else: return self.forward_test(img, img_meta, **kwargs) + + def show_result(self, + data, + result, + img_norm_cfg, + dataset='coco', + score_thr=0.3): + img_tensor = data['img'][0] + img_metas = data['img_meta'][0].data[0] + imgs = tensor2imgs(img_tensor, **img_norm_cfg) + assert len(imgs) == len(img_metas) + + if isinstance(dataset, str): + class_names = get_classes(dataset) + elif isinstance(dataset, list): + class_names = dataset + else: + raise TypeError('dataset must be a valid dataset name or a list' + ' of class names, not {}'.format(type(dataset))) + + for img, img_meta in zip(imgs, img_metas): + h, w, _ = img_meta['img_shape'] + img_show = img[:h, :w, :] + labels = [ + np.full(bbox.shape[0], i, dtype=np.int32) + for i, bbox in enumerate(result) + ] + labels = np.concatenate(labels) + bboxes = np.vstack(result) + mmcv.imshow_det_bboxes( + img_show, + bboxes, + labels, + class_names=class_names, + score_thr=score_thr)