From edb03937964b583a59dd1bddf76eaba82df9e8c0 Mon Sep 17 00:00:00 2001 From: Bo Li <drluodian@gmail.com> Date: Sun, 24 Mar 2019 10:09:01 +0800 Subject: [PATCH] Added mask visualization part to inference part and add out_file interface. (#403) * Update README.md * Update inference.py * Update README.md * Update inference.py Added mask visualization part for inferring. * Update README.md * Update inference.py * Update inference.py convert all tabs to spaces * Update inference.py --- mmdet/apis/inference.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py index a39763a..660613b 100644 --- a/mmdet/apis/inference.py +++ b/mmdet/apis/inference.py @@ -1,10 +1,11 @@ import mmcv import numpy as np +import pycocotools.mask as maskUtils import torch +from mmdet.core import get_classes from mmdet.datasets import to_tensor from mmdet.datasets.transforms import ImageTransform -from mmdet.core import get_classes def _prepare_data(img, img_transform, cfg, device): @@ -50,18 +51,33 @@ def inference_detector(model, imgs, cfg, device='cuda:0'): return _inference_generator(model, imgs, img_transform, cfg, device) -def show_result(img, result, dataset='coco', score_thr=0.3): +def show_result(img, result, dataset='coco', score_thr=0.3, out_file=None): + img = mmcv.imread(img) class_names = get_classes(dataset) + if isinstance(result, tuple): + bbox_result, segm_result = result + else: + bbox_result, segm_result = result, None + bboxes = np.vstack(bbox_result) + # draw segmentation masks + if segm_result is not None: + segms = mmcv.concat_list(segm_result) + inds = np.where(bboxes[:, -1] > score_thr)[0] + for i in inds: + color_mask = np.random.randint( + 0, 256, (1, 3), dtype=np.uint8) + mask = maskUtils.decode(segms[i]).astype(np.bool) + img[mask] = img[mask] * 0.5 + color_mask * 0.5 + # draw bounding boxes labels = [ np.full(bbox.shape[0], i, dtype=np.int32) - for i, bbox in enumerate(result) + for i, bbox in enumerate(bbox_result) ] labels = np.concatenate(labels) - bboxes = np.vstack(result) - img = mmcv.imread(img) mmcv.imshow_det_bboxes( img.copy(), bboxes, labels, class_names=class_names, - score_thr=score_thr) + score_thr=score_thr, + show=out_file is None) -- GitLab