From ebf59c895855ec15ca1973bc1d617288d5ff2a41 Mon Sep 17 00:00:00 2001 From: Karol Majek <karolmajek@gmail.com> Date: Fri, 27 Dec 2019 13:08:53 +0100 Subject: [PATCH] Per class colors in Mask results viz (#1834) * per class color in mask rcnn * fix max label id * pass flake8 * correct the annotation Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> Co-authored-by: Kai Chen <chenkaidev@gmail.com> --- mmdet/apis/inference.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py index 6724e85..a951501 100644 --- a/mmdet/apis/inference.py +++ b/mmdet/apis/inference.py @@ -148,20 +148,26 @@ def show_result(img, else: bbox_result, segm_result = result, None bboxes = np.vstack(bbox_result) + labels = [ + np.full(bbox.shape[0], i, dtype=np.int32) + for i, bbox in enumerate(bbox_result) + ] + labels = np.concatenate(labels) # draw segmentation masks if segm_result is not None: segms = mmcv.concat_list(segm_result) inds = np.where(bboxes[:, -1] > score_thr)[0] + np.random.seed(42) + color_masks = [ + np.random.randint(0, 256, (1, 3), dtype=np.uint8) + for _ in range(max(labels) + 1) + ] for i in inds: - color_mask = np.random.randint(0, 256, (1, 3), dtype=np.uint8) + i = int(i) + color_mask = color_masks[labels[i]] mask = maskUtils.decode(segms[i]).astype(np.bool) img[mask] = img[mask] * 0.5 + color_mask * 0.5 # draw bounding boxes - labels = [ - np.full(bbox.shape[0], i, dtype=np.int32) - for i, bbox in enumerate(bbox_result) - ] - labels = np.concatenate(labels) mmcv.imshow_det_bboxes( img, bboxes, -- GitLab