diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py index 6d26dc3a5aba773e521f3ffdcaf9ee7958b88843..cbaf4349d2da9c9bec2d7fe3847157c77463faef 100644 --- a/mmdet/models/detectors/base.py +++ b/mmdet/models/detectors/base.py @@ -4,6 +4,7 @@ from abc import ABCMeta, abstractmethod import mmcv import numpy as np import torch.nn as nn +import pycocotools.mask as maskUtils from mmdet.core import tensor2imgs, get_classes @@ -86,6 +87,11 @@ class BaseDetector(nn.Module): img_norm_cfg, dataset='coco', score_thr=0.3): + if isinstance(result, tuple): + bbox_result, segm_result = result + else: + bbox_result, segm_result = result, None + img_tensor = data['img'][0] img_metas = data['img_meta'][0].data[0] imgs = tensor2imgs(img_tensor, **img_norm_cfg) @@ -102,12 +108,23 @@ class BaseDetector(nn.Module): for img, img_meta in zip(imgs, img_metas): h, w, _ = img_meta['img_shape'] img_show = img[:h, :w, :] + + bboxes = np.vstack(bbox_result) + # draw segmentation masks + if segm_result is not None: + segms = mmcv.concat_list(segm_result) + inds = np.where(bboxes[:, -1] > score_thr)[0] + for i in inds: + color_mask = np.random.randint( + 0, 256, (1, 3), dtype=np.uint8) + mask = maskUtils.decode(segms[i]).astype(np.bool) + img_show[mask] = img_show[mask] * 0.5 + color_mask * 0.5 + # draw bounding boxes labels = [ np.full(bbox.shape[0], i, dtype=np.int32) - for i, bbox in enumerate(result) + for i, bbox in enumerate(bbox_result) ] labels = np.concatenate(labels) - bboxes = np.vstack(result) mmcv.imshow_det_bboxes( img_show, bboxes, diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py index c2ab5f9b5a1d2d5af314e2c5d7407e16192f19a2..5f843e61b56332111c22fc32914c7a11ad259261 100644 --- a/mmdet/models/detectors/cascade_rcnn.py +++ b/mmdet/models/detectors/cascade_rcnn.py @@ -306,14 +306,13 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): raise NotImplementedError def show_result(self, data, result, img_norm_cfg, **kwargs): - # TODO: show segmentation masks if self.with_mask: ms_bbox_result, ms_segm_result = result + if isinstance(ms_bbox_result, dict): + result = (ms_bbox_result['ensemble'], + ms_segm_result['ensemble']) else: - ms_bbox_result = result - if isinstance(ms_bbox_result, dict): - bbox_result = ms_bbox_result['ensemble'] - else: - bbox_result = ms_bbox_result - super(CascadeRCNN, self).show_result(data, bbox_result, img_norm_cfg, + if isinstance(result, dict): + result = result['ensemble'] + super(CascadeRCNN, self).show_result(data, result, img_norm_cfg, **kwargs) diff --git a/mmdet/models/detectors/mask_rcnn.py b/mmdet/models/detectors/mask_rcnn.py index 25a363e398f6c0d01e2f8bd53e05c9046a5275ac..0e308ab1e48ee67b91878c4ba570be43c250236f 100644 --- a/mmdet/models/detectors/mask_rcnn.py +++ b/mmdet/models/detectors/mask_rcnn.py @@ -25,10 +25,3 @@ class MaskRCNN(TwoStageDetector): train_cfg=train_cfg, test_cfg=test_cfg, pretrained=pretrained) - - def show_result(self, data, result, img_norm_cfg, **kwargs): - # TODO: show segmentation masks - assert isinstance(result, tuple) - assert len(result) == 2 # (bbox_results, segm_results) - super(MaskRCNN, self).show_result(data, result[0], img_norm_cfg, - **kwargs)