diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index 18b1cc10a3c0153118aa36a3b917d405d46d86e5..45a383bba9c04b769a967c26eeb8af4fba6d08d7 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -300,7 +300,7 @@ class Normalize(object): @PIPELINES.register_module class RandomCrop(object): - """Random crop the image & bboxes. + """Random crop the image & bboxes & masks. Args: crop_size (tuple): Expected size after cropping, (h, w). @@ -348,7 +348,7 @@ class RandomCrop(object): # filter and crop the masks if 'gt_masks' in results: valid_gt_masks = [] - for i in valid_inds: + for i in np.where(valid_inds)[0]: gt_mask = results['gt_masks'][i][crop_y1:crop_y2, crop_x1: crop_x2] valid_gt_masks.append(gt_mask)