Skip to content
Snippets Groups Projects
Commit 830effcd authored by Kai Chen's avatar Kai Chen
Browse files

add default result visualization for base detector

parent 65c3ebca
No related branches found
No related tags found
No related merge requests found
......@@ -95,7 +95,7 @@ def get_classes(dataset):
if mmcv.is_str(dataset):
if dataset in alias2name:
labels = eval(alias2name[dataset] + '_labels()')
labels = eval(alias2name[dataset] + '_classes()')
else:
raise ValueError('Unrecognized dataset: {}'.format(dataset))
else:
......
import logging
from abc import ABCMeta, abstractmethod
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmdet.core import tensor2imgs, get_classes
class BaseDetector(nn.Module):
"""Base class for detectors"""
......@@ -66,3 +70,38 @@ class BaseDetector(nn.Module):
return self.forward_train(img, img_meta, **kwargs)
else:
return self.forward_test(img, img_meta, **kwargs)
def show_result(self,
data,
result,
img_norm_cfg,
dataset='coco',
score_thr=0.3):
img_tensor = data['img'][0]
img_metas = data['img_meta'][0].data[0]
imgs = tensor2imgs(img_tensor, **img_norm_cfg)
assert len(imgs) == len(img_metas)
if isinstance(dataset, str):
class_names = get_classes(dataset)
elif isinstance(dataset, list):
class_names = dataset
else:
raise TypeError('dataset must be a valid dataset name or a list'
' of class names, not {}'.format(type(dataset)))
for img, img_meta in zip(imgs, img_metas):
h, w, _ = img_meta['img_shape']
img_show = img[:h, :w, :]
labels = [
np.full(bbox.shape[0], i, dtype=np.int32)
for i, bbox in enumerate(result)
]
labels = np.concatenate(labels)
bboxes = np.vstack(result)
mmcv.imshow_det_bboxes(
img_show,
bboxes,
labels,
class_names=class_names,
score_thr=score_thr)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment