diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index b38884e29bf1315aff10f4555f4cbb38504a7674..e0cca25aa408def66033ea7f46fe4aa038b2fd73 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -509,13 +509,15 @@ class Expand(object): mean (tuple): mean value of dataset. to_rgb (bool): if need to convert the order of mean to align with RGB. ratio_range (tuple): range of expand ratio. + prob (float): probability of applying this transformation """ def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4), - seg_ignore_label=None): + seg_ignore_label=None, + prob=0.5): self.to_rgb = to_rgb self.ratio_range = ratio_range if to_rgb: @@ -524,9 +526,10 @@ class Expand(object): self.mean = mean self.min_ratio, self.max_ratio = ratio_range self.seg_ignore_label = seg_ignore_label + self.prob = prob def __call__(self, results): - if random.randint(2): + if random.uniform(0, 1) > self.prob: return results img, boxes = [results[k] for k in ('img', 'gt_bboxes')]