diff --git a/mmdet/datasets/transforms.py b/mmdet/datasets/transforms.py index aed6cff1f0d732eb6341cb022fe672633f7c58c8..1e13d772ab19a8484f44832fd1bd76ff7a14fc2a 100644 --- a/mmdet/datasets/transforms.py +++ b/mmdet/datasets/transforms.py @@ -49,18 +49,23 @@ class ImageTransform(object): return img, img_shape, pad_shape, scale_factor -def bbox_flip(bboxes, img_shape): - """Flip bboxes horizontally. +def bbox_flip(bboxes, img_shape, direction='horizontal'): + """Flip bboxes horizontally or vertically. Args: bboxes(ndarray): shape (..., 4*k) img_shape(tuple): (height, width) """ assert bboxes.shape[-1] % 4 == 0 - w = img_shape[1] flipped = bboxes.copy() - flipped[..., 0::4] = w - bboxes[..., 2::4] - 1 - flipped[..., 2::4] = w - bboxes[..., 0::4] - 1 + if direction == 'horizontal': + w = img_shape[1] + flipped[..., 0::4] = w - bboxes[..., 2::4] - 1 + flipped[..., 2::4] = w - bboxes[..., 0::4] - 1 + else: + h = img_shape[0] + flipped[..., 1::4] = h - bboxes[..., 3::4] - 1 + flipped[..., 3::4] = h - bboxes[..., 1::4] - 1 return flipped