From f97d361d1159246d51a22052baba8594f5ce4c9b Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Sat, 27 Jul 2019 17:38:43 +0800
Subject: [PATCH] bug fix for mask resizing when aspect ratio is unfixed
 (#1050)

---
 mmdet/datasets/transforms.py | 26 ++++++++++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/mmdet/datasets/transforms.py b/mmdet/datasets/transforms.py
index ff575db..aed6cff 100644
--- a/mmdet/datasets/transforms.py
+++ b/mmdet/datasets/transforms.py
@@ -34,8 +34,8 @@ class ImageTransform(object):
         else:
             img, w_scale, h_scale = mmcv.imresize(
                 img, scale, return_scale=True)
-            scale_factor = np.array(
-                [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
+            scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
+                                    dtype=np.float32)
         img_shape = img.shape
         img = mmcv.imnormalize(img, self.mean, self.std, self.to_rgb)
         if flip:
@@ -99,10 +99,24 @@ class MaskTransform(object):
     """
 
     def __call__(self, masks, pad_shape, scale_factor, flip=False):
-        masks = [
-            mmcv.imrescale(mask, scale_factor, interpolation='nearest')
-            for mask in masks
-        ]
+        # aspect ratio unchanged
+        if isinstance(scale_factor, float):
+            masks = [
+                mmcv.imrescale(mask, scale_factor, interpolation='nearest')
+                for mask in masks
+            ]
+        # aspect ratio changed
+        else:
+            w_ratio, h_ratio = scale_factor[:2]
+            if masks:
+                h, w = masks[0].shape[:2]
+                new_h = int(np.round(h * h_ratio))
+                new_w = int(np.round(w * w_ratio))
+                new_size = (new_w, new_h)
+                masks = [
+                    mmcv.imresize(mask, new_size, interpolation='nearest')
+                    for mask in masks
+                ]
         if flip:
             masks = [mask[:, ::-1] for mask in masks]
         padded_masks = [
-- 
GitLab