diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index 45a383bba9c04b769a967c26eeb8af4fba6d08d7..60b86e4ff0acba20c74b7e58bea81fa513b2c05f 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -531,6 +531,15 @@ class Expand(object): results['img'] = expand_img results['gt_bboxes'] = boxes + + if 'gt_masks' in results: + expand_gt_masks = [] + for mask in results['gt_masks']: + expand_mask = np.full((int(h * ratio), int(w * ratio)), + 0).astype(mask.dtype) + expand_mask[top:top + h, left:left + w] = mask + expand_gt_masks.append(expand_mask) + results['gt_masks'] = expand_gt_masks return results def __repr__(self): @@ -604,6 +613,16 @@ class MinIoURandomCrop(object): results['img'] = img results['gt_bboxes'] = boxes results['gt_labels'] = labels + + if 'gt_masks' in results: + valid_masks = [ + results['gt_masks'][i] for i in range(len(mask)) + if mask[i] + ] + results['gt_masks'] = [ + gt_mask[patch[1]:patch[3], patch[0]:patch[2]] + for gt_mask in valid_masks + ] return results def __repr__(self):