From c514be0d79bd132e663c7ab4a63ef52b39f319ff Mon Sep 17 00:00:00 2001 From: Cao Yuhang <yhcao6@gmail.com> Date: Wed, 23 Oct 2019 18:48:59 +0800 Subject: [PATCH] Add seg transform to Expand and MinIoUCrop (#1550) * Add seg transform to Expand and MinIoUCrop * fix bug of repr --- mmdet/datasets/pipelines/transforms.py | 29 +++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index c758c91..b38884e 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -511,12 +511,19 @@ class Expand(object): ratio_range (tuple): range of expand ratio. """ - def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)): + def __init__(self, + mean=(0, 0, 0), + to_rgb=True, + ratio_range=(1, 4), + seg_ignore_label=None): + self.to_rgb = to_rgb + self.ratio_range = ratio_range if to_rgb: self.mean = mean[::-1] else: self.mean = mean self.min_ratio, self.max_ratio = ratio_range + self.seg_ignore_label = seg_ignore_label def __call__(self, results): if random.randint(2): @@ -544,12 +551,23 @@ class Expand(object): expand_mask[top:top + h, left:left + w] = mask expand_gt_masks.append(expand_mask) results['gt_masks'] = expand_gt_masks + + # not tested + if 'gt_semantic_seg' in results: + assert self.seg_ignore_label is not None + gt_seg = results['gt_semantic_seg'] + expand_gt_seg = np.full((int(h * ratio), int(w * ratio)), + self.seg_ignore_label).astype(gt_seg.dtype) + expand_gt_seg[top:top + h, left:left + w] = gt_seg + results['gt_semantic_seg'] = expand_gt_seg return results def __repr__(self): repr_str = self.__class__.__name__ - repr_str += '(mean={}, to_rgb={}, ratio_range={})'.format( - self.mean, self.to_rgb, self.ratio_range) + repr_str += '(mean={}, to_rgb={}, ratio_range={}, ' \ + 'seg_ignore_label={})'.format( + self.mean, self.to_rgb, self.ratio_range, + self.seg_ignore_label) return repr_str @@ -626,6 +644,11 @@ class MinIoURandomCrop(object): gt_mask[patch[1]:patch[3], patch[0]:patch[2]] for gt_mask in valid_masks ] + + # not tested + if 'gt_semantic_seg' in results: + results['gt_semantic_seg'] = results['gt_semantic_seg'][ + patch[1]:patch[3], patch[0]:patch[2]] return results def __repr__(self): -- GitLab