From cfc3cf0eeccee2978a03e3e002e8e174962b1b51 Mon Sep 17 00:00:00 2001 From: liushuchun <liuscgood@gmail.com> Date: Wed, 7 Aug 2019 20:54:22 +0800 Subject: [PATCH] Fix scale test error (#883) --- mmdet/models/bbox_heads/bbox_head.py | 5 ++++- mmdet/models/detectors/cascade_rcnn.py | 13 ++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py index 5c54828..6188fb8 100644 --- a/mmdet/models/bbox_heads/bbox_head.py +++ b/mmdet/models/bbox_heads/bbox_head.py @@ -152,7 +152,10 @@ class BBoxHead(nn.Module): bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0] - 1) if rescale: - bboxes /= scale_factor + if isinstance(scale_factor, float): + bboxes /= scale_factor + else: + bboxes /= torch.from_numpy(scale_factor).to(bboxes.device) if cfg is None: return bboxes, scores diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py index 9bd0207..7ecbdff 100644 --- a/mmdet/models/detectors/cascade_rcnn.py +++ b/mmdet/models/detectors/cascade_rcnn.py @@ -323,9 +323,16 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): mask_classes = self.mask_head[-1].num_classes - 1 segm_result = [[] for _ in range(mask_classes)] else: - _bboxes = ( - det_bboxes[:, :4] * - scale_factor if rescale else det_bboxes) + if isinstance(scale_factor, float): # aspect ratio fixed + _bboxes = ( + det_bboxes[:, :4] * + scale_factor if rescale else det_bboxes) + else: + _bboxes = ( + det_bboxes[:, :4] * + torch.from_numpy(scale_factor).to(det_bboxes.device) + if rescale else det_bboxes) + mask_rois = bbox2roi([_bboxes]) aug_masks = [] for i in range(self.num_stages): -- GitLab