Skip to content
Snippets Groups Projects
Unverified Commit 1c9afecf authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

recover PR#1115 to add flip direction (#1273)

parent d0c04187
No related branches found
No related tags found
No related merge requests found
......@@ -179,12 +179,14 @@ class RandomFlip(object):
flip_ratio (float, optional): The flipping probability.
"""
def __init__(self, flip_ratio=None):
def __init__(self, flip_ratio=None, direction='horizontal'):
self.flip_ratio = flip_ratio
self.direction = direction
if flip_ratio is not None:
assert flip_ratio >= 0 and flip_ratio <= 1
assert direction in ['horizontal', 'vertical']
def bbox_flip(self, bboxes, img_shape):
def bbox_flip(self, bboxes, img_shape, direction):
"""Flip bboxes horizontally.
Args:
......@@ -192,26 +194,41 @@ class RandomFlip(object):
img_shape(tuple): (height, width)
"""
assert bboxes.shape[-1] % 4 == 0
w = img_shape[1]
flipped = bboxes.copy()
flipped[..., 0::4] = w - bboxes[..., 2::4] - 1
flipped[..., 2::4] = w - bboxes[..., 0::4] - 1
if direction == 'horizontal':
w = img_shape[1]
flipped[..., 0::4] = w - bboxes[..., 2::4] - 1
flipped[..., 2::4] = w - bboxes[..., 0::4] - 1
elif direction == 'vertical':
h = img_shape[0]
flipped[..., 1::4] = h - bboxes[..., 3::4] - 1
flipped[..., 3::4] = h - bboxes[..., 1::4] - 1
else:
raise ValueError(
'Invalid flipping direction "{}"'.format(direction))
return flipped
def __call__(self, results):
if 'flip' not in results:
flip = True if np.random.rand() < self.flip_ratio else False
results['flip'] = flip
if 'flip_direction' not in results:
results['flip_direction'] = self.direction
if results['flip']:
# flip image
results['img'] = mmcv.imflip(results['img'])
results['img'] = mmcv.imflip(
results['img'], direction=results['flip_direction'])
# flip bboxes
for key in results.get('bbox_fields', []):
results[key] = self.bbox_flip(results[key],
results['img_shape'])
results['img_shape'],
results['flip_direction'])
# flip masks
for key in results.get('mask_fields', []):
results[key] = [mask[:, ::-1] for mask in results[key]]
results[key] = [
mmcv.imflip(mask, direction=results['flip_direction'])
for mask in results[key]
]
return results
def __repr__(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment