def aug_test(self, imgs, img_metas, proposals=None, rescale=False):
    """Test with augmentations.

    Runs the cascade bbox stages on every augmented view, merges the
    per-view detections, applies NMS once on the merged set and, if the
    model has mask heads, predicts masks for the surviving boxes on every
    view and merges them too.

    Args:
        imgs: Augmented views of the input; forwarded to
            ``self.extract_feats``.
        img_metas: One entry per augmented view. ``img_metas[k][0]`` is
            read for ``'img_shape'``, ``'scale_factor'`` and ``'flip'``
            (only one image per batch is supported), and
            ``img_metas[0][0]['ori_shape']`` is used for the masks.
        proposals: Accepted for interface compatibility only. It is not
            used; proposals always come from ``self.aug_test_rpn``.
        rescale: If True, the returned bboxes stay in the original image
            scale. If False, they are multiplied by the scale factor of
            ``imgs[0]`` so that they fit that view.

    Returns:
        ``bbox_result`` if the model has no mask branch, otherwise the
        tuple ``(bbox_result, segm_result)``. Masks are always requested
        at ``ori_shape``, independent of ``rescale``.
    """
    rcnn_test_cfg = self.test_cfg.rcnn

    # Features are recomputed for every pass (RPN, bbox, mask) instead of
    # being cached: this trades compute for a lower memory footprint.
    proposal_list = self.aug_test_rpn(
        self.extract_feats(imgs), img_metas, self.test_cfg.rpn)

    aug_bboxes = []
    aug_scores = []
    for x, img_meta in zip(self.extract_feats(imgs), img_metas):
        # only one image in the batch
        img_shape = img_meta[0]['img_shape']
        scale_factor = img_meta[0]['scale_factor']
        flip = img_meta[0]['flip']

        # Map the shared proposals into this view. A dedicated local name
        # is used so the ``proposals`` argument is not clobbered.
        aug_proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
                                     scale_factor, flip)
        rois = bbox2roi([aug_proposals])

        # "ms" in variable names means multi-stage
        ms_scores = []
        for i in range(self.num_stages):
            bbox_roi_extractor = self.bbox_roi_extractor[i]
            bbox_head = self.bbox_head[i]

            bbox_feats = bbox_roi_extractor(
                x[:len(bbox_roi_extractor.featmap_strides)], rois)
            if self.with_shared_head:
                bbox_feats = self.shared_head(bbox_feats)

            cls_score, bbox_pred = bbox_head(bbox_feats)
            ms_scores.append(cls_score)

            # Every stage but the last refines the RoIs for the next one.
            if i < self.num_stages - 1:
                bbox_label = cls_score.argmax(dim=1)
                rois = bbox_head.regress_by_class(rois, bbox_label,
                                                  bbox_pred, img_meta[0])

        # Average the classification scores of all stages; the box deltas
        # of the last stage are decoded against the last refined RoIs.
        cls_score = sum(ms_scores) / float(len(ms_scores))
        bboxes, scores = self.bbox_head[-1].get_det_bboxes(
            rois,
            cls_score,
            bbox_pred,
            img_shape,
            scale_factor,
            rescale=False,
            cfg=None)
        aug_bboxes.append(bboxes)
        aug_scores.append(scores)

    # after merging, bboxes will be rescaled to the original image size
    merged_bboxes, merged_scores = merge_aug_bboxes(
        aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
    det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
                                            rcnn_test_cfg.score_thr,
                                            rcnn_test_cfg.nms,
                                            rcnn_test_cfg.max_per_img)

    if rescale:
        result_bboxes = det_bboxes
    else:
        # Bring the boxes from the original image scale to the scale of
        # imgs[0]. Work on a copy: the mask branch below still needs the
        # original-scale boxes.
        result_bboxes = det_bboxes.clone()
        result_bboxes[:, :4] *= det_bboxes.new_tensor(
            img_metas[0][0]['scale_factor'])
    bbox_result = bbox2result(result_bboxes, det_labels,
                              self.bbox_head[-1].num_classes)

    if not self.with_mask:
        return bbox_result

    if det_bboxes.shape[0] == 0:
        # No detections: one empty list per foreground class.
        segm_result = [
            [] for _ in range(self.mask_head[-1].num_classes - 1)
        ]
        return bbox_result, segm_result

    aug_masks = []
    aug_img_metas = []
    for x, img_meta in zip(self.extract_feats(imgs), img_metas):
        img_shape = img_meta[0]['img_shape']
        scale_factor = img_meta[0]['scale_factor']
        flip = img_meta[0]['flip']

        # det_bboxes are in the original image scale; map them into this
        # augmented view before pooling mask features.
        _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape, scale_factor,
                               flip)
        mask_rois = bbox2roi([_bboxes])
        for i in range(self.num_stages):
            mask_roi_extractor = self.mask_roi_extractor[i]
            mask_feats = mask_roi_extractor(
                x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
            if self.with_shared_head:
                mask_feats = self.shared_head(mask_feats)
            mask_pred = self.mask_head[i](mask_feats)
            # One entry per (view, stage); the meta is repeated so both
            # lists stay aligned for merge_aug_masks.
            aug_masks.append(mask_pred.sigmoid().cpu().numpy())
            aug_img_metas.append(img_meta)

    merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
                                   self.test_cfg.rcnn)

    ori_shape = img_metas[0][0]['ori_shape']
    segm_result = self.mask_head[-1].get_seg_masks(
        merged_masks,
        det_bboxes,
        det_labels,
        rcnn_test_cfg,
        ori_shape,
        scale_factor=1.0,
        rescale=False)
    return bbox_result, segm_result
def aug_test(self, imgs, img_metas, proposals=None, rescale=False):
    """Test with augmentations.

    Same flow as the cascade test-time augmentation, with the optional
    semantic branch: semantic features are computed once per augmented
    view and passed to the bbox stages and fused into the mask features.

    Args:
        imgs: Augmented views of the input; forwarded to
            ``self.extract_feats``.
        img_metas: One entry per augmented view. ``img_metas[k][0]`` is
            read for ``'img_shape'``, ``'scale_factor'`` and ``'flip'``
            (only one image per batch is supported), and
            ``img_metas[0][0]['ori_shape']`` is used for the masks.
        proposals: Accepted for interface compatibility only. It is not
            used; proposals always come from ``self.aug_test_rpn``.
        rescale: If True, the returned bboxes stay in the original image
            scale. If False, they are multiplied by the scale factor of
            ``imgs[0]`` so that they fit that view.

    Returns:
        ``bbox_result`` if the model has no mask branch, otherwise the
        tuple ``(bbox_result, segm_result)``. Masks are always requested
        at ``ori_shape``, independent of ``rescale``.
    """
    if self.with_semantic:
        semantic_feats = [
            self.semantic_head(feat)[1]
            for feat in self.extract_feats(imgs)
        ]
    else:
        # Placeholders keep the zip() below aligned with img_metas.
        semantic_feats = [None] * len(img_metas)

    rcnn_test_cfg = self.test_cfg.rcnn

    # Features are recomputed for every pass (semantic, RPN, bbox, mask)
    # instead of being cached: this trades compute for lower memory use.
    proposal_list = self.aug_test_rpn(
        self.extract_feats(imgs), img_metas, self.test_cfg.rpn)

    aug_bboxes = []
    aug_scores = []
    for x, img_meta, semantic in zip(
            self.extract_feats(imgs), img_metas, semantic_feats):
        # only one image in the batch
        img_shape = img_meta[0]['img_shape']
        scale_factor = img_meta[0]['scale_factor']
        flip = img_meta[0]['flip']

        # Map the shared proposals into this view. A dedicated local name
        # is used so the ``proposals`` argument is not clobbered.
        aug_proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
                                     scale_factor, flip)
        rois = bbox2roi([aug_proposals])

        # "ms" in variable names means multi-stage
        ms_scores = []
        for i in range(self.num_stages):
            bbox_head = self.bbox_head[i]
            cls_score, bbox_pred = self._bbox_forward_test(
                i, x, rois, semantic_feat=semantic)
            ms_scores.append(cls_score)

            # Every stage but the last refines the RoIs for the next one.
            if i < self.num_stages - 1:
                bbox_label = cls_score.argmax(dim=1)
                rois = bbox_head.regress_by_class(rois, bbox_label,
                                                  bbox_pred, img_meta[0])

        # Average the classification scores of all stages; the box deltas
        # of the last stage are decoded against the last refined RoIs.
        cls_score = sum(ms_scores) / float(len(ms_scores))
        bboxes, scores = self.bbox_head[-1].get_det_bboxes(
            rois,
            cls_score,
            bbox_pred,
            img_shape,
            scale_factor,
            rescale=False,
            cfg=None)
        aug_bboxes.append(bboxes)
        aug_scores.append(scores)

    # after merging, bboxes will be rescaled to the original image size
    merged_bboxes, merged_scores = merge_aug_bboxes(
        aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
    det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
                                            rcnn_test_cfg.score_thr,
                                            rcnn_test_cfg.nms,
                                            rcnn_test_cfg.max_per_img)

    if rescale:
        result_bboxes = det_bboxes
    else:
        # Bring the boxes from the original image scale to the scale of
        # imgs[0]. Work on a copy: the mask branch below still needs the
        # original-scale boxes.
        result_bboxes = det_bboxes.clone()
        result_bboxes[:, :4] *= det_bboxes.new_tensor(
            img_metas[0][0]['scale_factor'])
    bbox_result = bbox2result(result_bboxes, det_labels,
                              self.bbox_head[-1].num_classes)

    if not self.with_mask:
        return bbox_result

    if det_bboxes.shape[0] == 0:
        # No detections: one empty list per foreground class.
        segm_result = [
            [] for _ in range(self.mask_head[-1].num_classes - 1)
        ]
        return bbox_result, segm_result

    aug_masks = []
    aug_img_metas = []
    # RoI features are pooled once per view with the last extractor and
    # shared by all mask stages.
    mask_roi_extractor = self.mask_roi_extractor[-1]
    for x, img_meta, semantic in zip(
            self.extract_feats(imgs), img_metas, semantic_feats):
        img_shape = img_meta[0]['img_shape']
        scale_factor = img_meta[0]['scale_factor']
        flip = img_meta[0]['flip']

        # det_bboxes are in the original image scale; map them into this
        # augmented view before pooling mask features.
        _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape, scale_factor,
                               flip)
        mask_rois = bbox2roi([_bboxes])
        mask_feats = mask_roi_extractor(
            x[:len(mask_roi_extractor.featmap_strides)], mask_rois)

        if self.with_semantic:
            mask_semantic_feat = self.semantic_roi_extractor([semantic],
                                                             mask_rois)
            # Match spatial sizes before the element-wise fusion.
            if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
                mask_semantic_feat = F.adaptive_avg_pool2d(
                    mask_semantic_feat, mask_feats.shape[-2:])
            mask_feats += mask_semantic_feat

        # With mask information flow, each stage also receives the
        # feature returned by the previous stage (None for the first).
        last_feat = None
        for i in range(self.num_stages):
            mask_head = self.mask_head[i]
            if self.mask_info_flow:
                mask_pred, last_feat = mask_head(mask_feats, last_feat)
            else:
                mask_pred = mask_head(mask_feats)
            # One entry per (view, stage); the meta is repeated so both
            # lists stay aligned for merge_aug_masks.
            aug_masks.append(mask_pred.sigmoid().cpu().numpy())
            aug_img_metas.append(img_meta)

    merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
                                   self.test_cfg.rcnn)

    ori_shape = img_metas[0][0]['ori_shape']
    segm_result = self.mask_head[-1].get_seg_masks(
        merged_masks,
        det_bboxes,
        det_labels,
        rcnn_test_cfg,
        ori_shape,
        scale_factor=1.0,
        rescale=False)
    return bbox_result, segm_result