diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index c758c919d31fe65e6e0c91242dac2462238525a1..b38884e29bf1315aff10f4555f4cbb38504a7674 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -511,12 +511,19 @@ class Expand(object): ratio_range (tuple): range of expand ratio. """ - def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)): + def __init__(self, + mean=(0, 0, 0), + to_rgb=True, + ratio_range=(1, 4), + seg_ignore_label=None): + self.to_rgb = to_rgb + self.ratio_range = ratio_range if to_rgb: self.mean = mean[::-1] else: self.mean = mean self.min_ratio, self.max_ratio = ratio_range + self.seg_ignore_label = seg_ignore_label def __call__(self, results): if random.randint(2): @@ -544,12 +551,23 @@ class Expand(object): expand_mask[top:top + h, left:left + w] = mask expand_gt_masks.append(expand_mask) results['gt_masks'] = expand_gt_masks + + # not tested + if 'gt_semantic_seg' in results: + assert self.seg_ignore_label is not None + gt_seg = results['gt_semantic_seg'] + expand_gt_seg = np.full((int(h * ratio), int(w * ratio)), + self.seg_ignore_label).astype(gt_seg.dtype) + expand_gt_seg[top:top + h, left:left + w] = gt_seg + results['gt_semantic_seg'] = expand_gt_seg return results def __repr__(self): repr_str = self.__class__.__name__ - repr_str += '(mean={}, to_rgb={}, ratio_range={})'.format( - self.mean, self.to_rgb, self.ratio_range) + repr_str += '(mean={}, to_rgb={}, ratio_range={}, ' \ + 'seg_ignore_label={})'.format( + self.mean, self.to_rgb, self.ratio_range, + self.seg_ignore_label) return repr_str @@ -626,6 +644,11 @@ class MinIoURandomCrop(object): gt_mask[patch[1]:patch[3], patch[0]:patch[2]] for gt_mask in valid_masks ] + + # not tested + if 'gt_semantic_seg' in results: + results['gt_semantic_seg'] = results['gt_semantic_seg'][ + patch[1]:patch[3], patch[0]:patch[2]] return results def __repr__(self):