Skip to content
Snippets Groups Projects
Commit c514be0d authored by Cao Yuhang's avatar Cao Yuhang Committed by Kai Chen
Browse files

Add seg transform to Expand and MinIoUCrop (#1550)

* Add seg transform to Expand and MinIoUCrop

* fix bug of repr
parent fa2db159
No related branches found
No related tags found
No related merge requests found
...@@ -511,12 +511,19 @@ class Expand(object): ...@@ -511,12 +511,19 @@ class Expand(object):
ratio_range (tuple): range of expand ratio. ratio_range (tuple): range of expand ratio.
""" """
def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)): def __init__(self,
mean=(0, 0, 0),
to_rgb=True,
ratio_range=(1, 4),
seg_ignore_label=None):
self.to_rgb = to_rgb
self.ratio_range = ratio_range
if to_rgb: if to_rgb:
self.mean = mean[::-1] self.mean = mean[::-1]
else: else:
self.mean = mean self.mean = mean
self.min_ratio, self.max_ratio = ratio_range self.min_ratio, self.max_ratio = ratio_range
self.seg_ignore_label = seg_ignore_label
def __call__(self, results): def __call__(self, results):
if random.randint(2): if random.randint(2):
...@@ -544,12 +551,23 @@ class Expand(object): ...@@ -544,12 +551,23 @@ class Expand(object):
expand_mask[top:top + h, left:left + w] = mask expand_mask[top:top + h, left:left + w] = mask
expand_gt_masks.append(expand_mask) expand_gt_masks.append(expand_mask)
results['gt_masks'] = expand_gt_masks results['gt_masks'] = expand_gt_masks
# not tested
if 'gt_semantic_seg' in results:
assert self.seg_ignore_label is not None
gt_seg = results['gt_semantic_seg']
expand_gt_seg = np.full((int(h * ratio), int(w * ratio)),
self.seg_ignore_label).astype(gt_seg.dtype)
expand_gt_seg[top:top + h, left:left + w] = gt_seg
results['gt_semantic_seg'] = expand_gt_seg
return results return results
def __repr__(self): def __repr__(self):
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += '(mean={}, to_rgb={}, ratio_range={})'.format( repr_str += '(mean={}, to_rgb={}, ratio_range={}, ' \
self.mean, self.to_rgb, self.ratio_range) 'seg_ignore_label={})'.format(
self.mean, self.to_rgb, self.ratio_range,
self.seg_ignore_label)
return repr_str return repr_str
...@@ -626,6 +644,11 @@ class MinIoURandomCrop(object): ...@@ -626,6 +644,11 @@ class MinIoURandomCrop(object):
gt_mask[patch[1]:patch[3], patch[0]:patch[2]] gt_mask[patch[1]:patch[3], patch[0]:patch[2]]
for gt_mask in valid_masks for gt_mask in valid_masks
] ]
# not tested
if 'gt_semantic_seg' in results:
results['gt_semantic_seg'] = results['gt_semantic_seg'][
patch[1]:patch[3], patch[0]:patch[2]]
return results return results
def __repr__(self): def __repr__(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment