From edb03937964b583a59dd1bddf76eaba82df9e8c0 Mon Sep 17 00:00:00 2001
From: Bo Li <drluodian@gmail.com>
Date: Sun, 24 Mar 2019 10:09:01 +0800
Subject: [PATCH] Added mask visualization part to inference part and add
 out_file interface. (#403)

* Update README.md

* Update inference.py

* Update README.md

* Update inference.py

Added mask visualization part for inferring.

* Update README.md

* Update inference.py

* Update inference.py

convert all tabs to spaces

* Update inference.py
---
 mmdet/apis/inference.py | 28 ++++++++++++++++++++++------
 1 file changed, 22 insertions(+), 6 deletions(-)

diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py
index a39763a..660613b 100644
--- a/mmdet/apis/inference.py
+++ b/mmdet/apis/inference.py
@@ -1,10 +1,11 @@
 import mmcv
 import numpy as np
+import pycocotools.mask as maskUtils
 import torch
 
+from mmdet.core import get_classes
 from mmdet.datasets import to_tensor
 from mmdet.datasets.transforms import ImageTransform
-from mmdet.core import get_classes
 
 
 def _prepare_data(img, img_transform, cfg, device):
@@ -50,18 +51,33 @@ def inference_detector(model, imgs, cfg, device='cuda:0'):
         return _inference_generator(model, imgs, img_transform, cfg, device)
 
 
-def show_result(img, result, dataset='coco', score_thr=0.3):
+def show_result(img, result, dataset='coco', score_thr=0.3, out_file=None):
+    img = mmcv.imread(img)
     class_names = get_classes(dataset)
+    if isinstance(result, tuple):
+        bbox_result, segm_result = result
+    else:
+        bbox_result, segm_result = result, None
+    bboxes = np.vstack(bbox_result)
+    # draw segmentation masks
+    if segm_result is not None:
+        segms = mmcv.concat_list(segm_result)
+        inds = np.where(bboxes[:, -1] > score_thr)[0]
+        for i in inds:
+            color_mask = np.random.randint(
+                0, 256, (1, 3), dtype=np.uint8)
+            mask = maskUtils.decode(segms[i]).astype(np.bool)
+            img[mask] = img[mask] * 0.5 + color_mask * 0.5
+    # draw bounding boxes
     labels = [
         np.full(bbox.shape[0], i, dtype=np.int32)
-        for i, bbox in enumerate(result)
+        for i, bbox in enumerate(bbox_result)
     ]
     labels = np.concatenate(labels)
-    bboxes = np.vstack(result)
-    img = mmcv.imread(img)
     mmcv.imshow_det_bboxes(
         img.copy(),
         bboxes,
         labels,
         class_names=class_names,
-        score_thr=score_thr)
+        score_thr=score_thr,
+        show=out_file is None)
-- 
GitLab