diff --git a/mmdet/core/anchor/anchor_target.py b/mmdet/core/anchor/anchor_target.py index 91238e6bdf9101ae37eee1b39f9b640a45b68134..bba372ffae5d6ffb43bdb21bd76eb3e3e2725f5a 100644 --- a/mmdet/core/anchor/anchor_target.py +++ b/mmdet/core/anchor/anchor_target.py @@ -164,10 +164,10 @@ def anchor_inside_flags(flat_anchors, valid_flags, img_shape, img_h, img_w = img_shape[:2] if allowed_border >= 0: inside_flags = valid_flags & \ - (flat_anchors[:, 0] >= -allowed_border) & \ - (flat_anchors[:, 1] >= -allowed_border) & \ - (flat_anchors[:, 2] < img_w + allowed_border) & \ - (flat_anchors[:, 3] < img_h + allowed_border) + (flat_anchors[:, 0] >= -allowed_border).type(torch.uint8) & \ + (flat_anchors[:, 1] >= -allowed_border).type(torch.uint8) & \ + (flat_anchors[:, 2] < img_w + allowed_border).type(torch.uint8) & \ + (flat_anchors[:, 3] < img_h + allowed_border).type(torch.uint8) else: inside_flags = valid_flags return inside_flags