diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py index 6724e85c85009366cef8c070a466f8d1c2b65ff8..a95150196892b34d34898769b4de6132758b5c02 100644 --- a/mmdet/apis/inference.py +++ b/mmdet/apis/inference.py @@ -148,20 +148,26 @@ def show_result(img, else: bbox_result, segm_result = result, None bboxes = np.vstack(bbox_result) + labels = [ + np.full(bbox.shape[0], i, dtype=np.int32) + for i, bbox in enumerate(bbox_result) + ] + labels = np.concatenate(labels) # draw segmentation masks if segm_result is not None: segms = mmcv.concat_list(segm_result) inds = np.where(bboxes[:, -1] > score_thr)[0] + np.random.seed(42) + color_masks = [ + np.random.randint(0, 256, (1, 3), dtype=np.uint8) + for _ in range(max(labels) + 1) + ] for i in inds: - color_mask = np.random.randint(0, 256, (1, 3), dtype=np.uint8) + i = int(i) + color_mask = color_masks[labels[i]] mask = maskUtils.decode(segms[i]).astype(np.bool) img[mask] = img[mask] * 0.5 + color_mask * 0.5 # draw bounding boxes - labels = [ - np.full(bbox.shape[0], i, dtype=np.int32) - for i, bbox in enumerate(bbox_result) - ] - labels = np.concatenate(labels) mmcv.imshow_det_bboxes( img, bboxes,