From ebf59c895855ec15ca1973bc1d617288d5ff2a41 Mon Sep 17 00:00:00 2001
From: Karol Majek <karolmajek@gmail.com>
Date: Fri, 27 Dec 2019 13:08:53 +0100
Subject: [PATCH] Per class colors in Mask results viz (#1834)

* per class color in mask rcnn

* fix max label id

* pass flake8

* correct the annotation

Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
Co-authored-by: Kai Chen <chenkaidev@gmail.com>
---
 mmdet/apis/inference.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py
index 6724e85..a951501 100644
--- a/mmdet/apis/inference.py
+++ b/mmdet/apis/inference.py
@@ -148,20 +148,26 @@ def show_result(img,
     else:
         bbox_result, segm_result = result, None
     bboxes = np.vstack(bbox_result)
+    labels = [
+        np.full(bbox.shape[0], i, dtype=np.int32)
+        for i, bbox in enumerate(bbox_result)
+    ]
+    labels = np.concatenate(labels)
     # draw segmentation masks
     if segm_result is not None:
         segms = mmcv.concat_list(segm_result)
         inds = np.where(bboxes[:, -1] > score_thr)[0]
+        np.random.seed(42)
+        color_masks = [
+            np.random.randint(0, 256, (1, 3), dtype=np.uint8)
+            for _ in range(max(labels) + 1)
+        ]
         for i in inds:
-            color_mask = np.random.randint(0, 256, (1, 3), dtype=np.uint8)
+            i = int(i)
+            color_mask = color_masks[labels[i]]
             mask = maskUtils.decode(segms[i]).astype(np.bool)
             img[mask] = img[mask] * 0.5 + color_mask * 0.5
     # draw bounding boxes
-    labels = [
-        np.full(bbox.shape[0], i, dtype=np.int32)
-        for i, bbox in enumerate(bbox_result)
-    ]
-    labels = np.concatenate(labels)
     mmcv.imshow_det_bboxes(
         img,
         bboxes,
-- 
GitLab