From cfc3cf0eeccee2978a03e3e002e8e174962b1b51 Mon Sep 17 00:00:00 2001
From: liushuchun <liuscgood@gmail.com>
Date: Wed, 7 Aug 2019 20:54:22 +0800
Subject: [PATCH] Fix scale test error (#883)

---
 mmdet/models/bbox_heads/bbox_head.py   |  5 ++++-
 mmdet/models/detectors/cascade_rcnn.py | 13 ++++++++++---
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py
index 5c54828..6188fb8 100644
--- a/mmdet/models/bbox_heads/bbox_head.py
+++ b/mmdet/models/bbox_heads/bbox_head.py
@@ -152,7 +152,10 @@ class BBoxHead(nn.Module):
                 bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0] - 1)
 
         if rescale:
-            bboxes /= scale_factor
+            if isinstance(scale_factor, float):
+                bboxes /= scale_factor
+            else:
+                bboxes /= torch.from_numpy(scale_factor).to(bboxes.device)
 
         if cfg is None:
             return bboxes, scores
diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py
index 9bd0207..7ecbdff 100644
--- a/mmdet/models/detectors/cascade_rcnn.py
+++ b/mmdet/models/detectors/cascade_rcnn.py
@@ -323,9 +323,16 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
                 mask_classes = self.mask_head[-1].num_classes - 1
                 segm_result = [[] for _ in range(mask_classes)]
             else:
-                _bboxes = (
-                    det_bboxes[:, :4] *
-                    scale_factor if rescale else det_bboxes)
+                if isinstance(scale_factor, float):  # aspect ratio fixed
+                    _bboxes = (
+                        det_bboxes[:, :4] *
+                        scale_factor if rescale else det_bboxes)
+                else:
+                    _bboxes = (
+                        det_bboxes[:, :4] *
+                        torch.from_numpy(scale_factor).to(det_bboxes.device)
+                        if rescale else det_bboxes)
+
                 mask_rois = bbox2roi([_bboxes])
                 aug_masks = []
                 for i in range(self.num_stages):
-- 
GitLab