From 9118a94a4a0eee4ae2a24908f6c6684fb9a76468 Mon Sep 17 00:00:00 2001
From: Cao Yuhang <yhcao6@gmail.com>
Date: Tue, 24 Dec 2019 17:03:34 +0800
Subject: [PATCH] Split SegResizeFlipPadRescale into different existing
 transforms (#1852)

* Split seg trans

* Modify cfg

* fix typo
---
 configs/hrnet/htc_hrnetv2p_w32_20e.py         |  2 +-
 ...-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py |  2 +-
 configs/htc/htc_r101_fpn_20e.py               |  2 +-
 configs/htc/htc_r50_fpn_1x.py                 |  2 +-
 configs/htc/htc_r50_fpn_20e.py                |  2 +-
 configs/htc/htc_x101_32x4d_fpn_20e_16gpu.py   |  2 +-
 configs/htc/htc_x101_64x4d_fpn_20e_16gpu.py   |  2 +-
 mmdet/datasets/custom.py                      |  1 +
 mmdet/datasets/pipelines/__init__.py          |  6 +--
 mmdet/datasets/pipelines/loading.py           |  1 +
 mmdet/datasets/pipelines/transforms.py        | 54 +++++++++----------
 11 files changed, 39 insertions(+), 37 deletions(-)

diff --git a/configs/hrnet/htc_hrnetv2p_w32_20e.py b/configs/hrnet/htc_hrnetv2p_w32_20e.py
index 8279f24..a148963 100644
--- a/configs/hrnet/htc_hrnetv2p_w32_20e.py
+++ b/configs/hrnet/htc_hrnetv2p_w32_20e.py
@@ -221,7 +221,7 @@ train_pipeline = [
     dict(type='RandomFlip', flip_ratio=0.5),
     dict(type='Normalize', **img_norm_cfg),
     dict(type='Pad', size_divisor=32),
-    dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8),
+    dict(type='SegRescale', scale_factor=1 / 8),
     dict(type='DefaultFormatBundle'),
     dict(
         type='Collect',
diff --git a/configs/htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py b/configs/htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py
index 2072f29..275b7fb 100644
--- a/configs/htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py
+++ b/configs/htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py
@@ -217,7 +217,7 @@ train_pipeline = [
     dict(type='RandomFlip', flip_ratio=0.5),
     dict(type='Normalize', **img_norm_cfg),
     dict(type='Pad', size_divisor=32),
-    dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8),
+    dict(type='SegRescale', scale_factor=1 / 8),
     dict(type='DefaultFormatBundle'),
     dict(
         type='Collect',
diff --git a/configs/htc/htc_r101_fpn_20e.py b/configs/htc/htc_r101_fpn_20e.py
index 661c564..4f0ec72 100644
--- a/configs/htc/htc_r101_fpn_20e.py
+++ b/configs/htc/htc_r101_fpn_20e.py
@@ -205,7 +205,7 @@ train_pipeline = [
     dict(type='RandomFlip', flip_ratio=0.5),
     dict(type='Normalize', **img_norm_cfg),
     dict(type='Pad', size_divisor=32),
-    dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8),
+    dict(type='SegRescale', scale_factor=1 / 8),
     dict(type='DefaultFormatBundle'),
     dict(
         type='Collect',
diff --git a/configs/htc/htc_r50_fpn_1x.py b/configs/htc/htc_r50_fpn_1x.py
index 4945f2e..34ac663 100644
--- a/configs/htc/htc_r50_fpn_1x.py
+++ b/configs/htc/htc_r50_fpn_1x.py
@@ -205,7 +205,7 @@ train_pipeline = [
     dict(type='RandomFlip', flip_ratio=0.5),
     dict(type='Normalize', **img_norm_cfg),
     dict(type='Pad', size_divisor=32),
-    dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8),
+    dict(type='SegRescale', scale_factor=1 / 8),
     dict(type='DefaultFormatBundle'),
     dict(
         type='Collect',
diff --git a/configs/htc/htc_r50_fpn_20e.py b/configs/htc/htc_r50_fpn_20e.py
index ccf73a0..47714ac 100644
--- a/configs/htc/htc_r50_fpn_20e.py
+++ b/configs/htc/htc_r50_fpn_20e.py
@@ -205,7 +205,7 @@ train_pipeline = [
     dict(type='RandomFlip', flip_ratio=0.5),
     dict(type='Normalize', **img_norm_cfg),
     dict(type='Pad', size_divisor=32),
-    dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8),
+    dict(type='SegRescale', scale_factor=1 / 8),
     dict(type='DefaultFormatBundle'),
     dict(
         type='Collect',
diff --git a/configs/htc/htc_x101_32x4d_fpn_20e_16gpu.py b/configs/htc/htc_x101_32x4d_fpn_20e_16gpu.py
index 915a54e..b211ba9 100644
--- a/configs/htc/htc_x101_32x4d_fpn_20e_16gpu.py
+++ b/configs/htc/htc_x101_32x4d_fpn_20e_16gpu.py
@@ -207,7 +207,7 @@ train_pipeline = [
     dict(type='RandomFlip', flip_ratio=0.5),
     dict(type='Normalize', **img_norm_cfg),
     dict(type='Pad', size_divisor=32),
-    dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8),
+    dict(type='SegRescale', scale_factor=1 / 8),
     dict(type='DefaultFormatBundle'),
     dict(
         type='Collect',
diff --git a/configs/htc/htc_x101_64x4d_fpn_20e_16gpu.py b/configs/htc/htc_x101_64x4d_fpn_20e_16gpu.py
index 99ceefc..0186d72 100644
--- a/configs/htc/htc_x101_64x4d_fpn_20e_16gpu.py
+++ b/configs/htc/htc_x101_64x4d_fpn_20e_16gpu.py
@@ -207,7 +207,7 @@ train_pipeline = [
     dict(type='RandomFlip', flip_ratio=0.5),
     dict(type='Normalize', **img_norm_cfg),
     dict(type='Pad', size_divisor=32),
-    dict(type='SegResizeFlipPadRescale', scale_factor=1 / 8),
+    dict(type='SegRescale', scale_factor=1 / 8),
     dict(type='DefaultFormatBundle'),
     dict(
         type='Collect',
diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py
index d068543..935b39d 100644
--- a/mmdet/datasets/custom.py
+++ b/mmdet/datasets/custom.py
@@ -98,6 +98,7 @@ class CustomDataset(Dataset):
         results['proposal_file'] = self.proposal_file
         results['bbox_fields'] = []
         results['mask_fields'] = []
+        results['seg_fields'] = []
 
     def _filter_imgs(self, min_size=32):
         """Filter images too small."""
diff --git a/mmdet/datasets/pipelines/__init__.py b/mmdet/datasets/pipelines/__init__.py
index ae55b25..bfe375c 100644
--- a/mmdet/datasets/pipelines/__init__.py
+++ b/mmdet/datasets/pipelines/__init__.py
@@ -5,12 +5,12 @@ from .loading import LoadAnnotations, LoadImageFromFile, LoadProposals
 from .test_aug import MultiScaleFlipAug
 from .transforms import (Albu, Expand, MinIoURandomCrop, Normalize, Pad,
                          PhotoMetricDistortion, RandomCrop, RandomFlip, Resize,
-                         SegResizeFlipPadRescale)
+                         SegRescale)
 
 __all__ = [
     'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
     'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
     'LoadProposals', 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad',
-    'RandomCrop', 'Normalize', 'SegResizeFlipPadRescale', 'MinIoURandomCrop',
-    'Expand', 'PhotoMetricDistortion', 'Albu'
+    'RandomCrop', 'Normalize', 'SegRescale', 'MinIoURandomCrop', 'Expand',
+    'PhotoMetricDistortion', 'Albu'
 ]
diff --git a/mmdet/datasets/pipelines/loading.py b/mmdet/datasets/pipelines/loading.py
index 9f3007e..f4aa6de 100644
--- a/mmdet/datasets/pipelines/loading.py
+++ b/mmdet/datasets/pipelines/loading.py
@@ -91,6 +91,7 @@ class LoadAnnotations(object):
         results['gt_semantic_seg'] = mmcv.imread(
             osp.join(results['seg_prefix'], results['ann_info']['seg_map']),
             flag='unchanged').squeeze()
+        results['seg_fields'].append('gt_semantic_seg')
         return results
 
     def __call__(self, results):
diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py
index dc38597..f57eb4e 100644
--- a/mmdet/datasets/pipelines/transforms.py
+++ b/mmdet/datasets/pipelines/transforms.py
@@ -149,12 +149,23 @@ class Resize(object):
                 ]
             results[key] = masks
 
+    def _resize_seg(self, results):
+        for key in results.get('seg_fields', []):
+            if self.keep_ratio:
+                gt_seg = mmcv.imrescale(
+                    results[key], results['scale'], interpolation='nearest')
+            else:
+                gt_seg = mmcv.imresize(
+                    results[key], results['scale'], interpolation='nearest')
+            results['gt_semantic_seg'] = gt_seg
+
     def __call__(self, results):
         if 'scale' not in results:
             self._random_scale(results)
         self._resize_img(results)
         self._resize_bboxes(results)
         self._resize_masks(results)
+        self._resize_seg(results)
         return results
 
     def __repr__(self):
@@ -229,6 +240,11 @@ class RandomFlip(object):
                     mmcv.imflip(mask, direction=results['flip_direction'])
                     for mask in results[key]
                 ]
+
+            # flip segs
+            for key in results.get('seg_fields', []):
+                results[key] = mmcv.imflip(
+                    results[key], direction=results['flip_direction'])
         return results
 
     def __repr__(self):
@@ -280,9 +296,14 @@ class Pad(object):
             else:
                 results[key] = np.empty((0, ) + pad_shape, dtype=np.uint8)
 
+    def _pad_seg(self, results):
+        for key in results.get('seg_fields', []):
+            results[key] = mmcv.impad(results[key], results['pad_shape'][:2])
+
     def __call__(self, results):
         self._pad_img(results)
         self._pad_masks(results)
+        self._pad_seg(results)
         return results
 
     def __repr__(self):
@@ -386,15 +407,8 @@ class RandomCrop(object):
 
 
 @PIPELINES.register_module
-class SegResizeFlipPadRescale(object):
-    """A sequential transforms to semantic segmentation maps.
-
-    The same pipeline as input images is applied to the semantic segmentation
-    map, and finally rescale it by some scale factor. The transforms include:
-    1. resize
-    2. flip
-    3. pad
-    4. rescale (so that the final size can be different from the image size)
+class SegRescale(object):
+    """Rescale semantic segmentation maps.
 
     Args:
         scale_factor (float): The scale factor of the final output.
@@ -404,24 +418,10 @@ class SegResizeFlipPadRescale(object):
         self.scale_factor = scale_factor
 
     def __call__(self, results):
-        if results['keep_ratio']:
-            gt_seg = mmcv.imrescale(
-                results['gt_semantic_seg'],
-                results['scale'],
-                interpolation='nearest')
-        else:
-            gt_seg = mmcv.imresize(
-                results['gt_semantic_seg'],
-                results['scale'],
-                interpolation='nearest')
-        if results['flip']:
-            gt_seg = mmcv.imflip(gt_seg)
-        if gt_seg.shape != results['pad_shape']:
-            gt_seg = mmcv.impad(gt_seg, results['pad_shape'][:2])
-        if self.scale_factor != 1:
-            gt_seg = mmcv.imrescale(
-                gt_seg, self.scale_factor, interpolation='nearest')
-        results['gt_semantic_seg'] = gt_seg
+        for key in results.get('seg_fields', []):
+            if self.scale_factor != 1:
+                results[key] = mmcv.imrescale(
+                    results[key], self.scale_factor, interpolation='nearest')
         return results
 
     def __repr__(self):
-- 
GitLab