diff --git a/mmdet/core/bbox/assigners/max_iou_assigner.py b/mmdet/core/bbox/assigners/max_iou_assigner.py index dceee3126df4a37902d2b2e6d23b10434ca13e5d..57a1e750456da7cd5fa251aed3f36416563ce7d0 100644 --- a/mmdet/core/bbox/assigners/max_iou_assigner.py +++ b/mmdet/core/bbox/assigners/max_iou_assigner.py @@ -26,6 +26,8 @@ class MaxIoUAssigner(BaseAssigner): ignore_iof_thr (float): IoF threshold for ignoring bboxes (if `gt_bboxes_ignore` is specified). Negative values mean not ignoring any bboxes. + ignore_wrt_candidates (bool): Whether to compute the iof between + `bboxes` and `gt_bboxes_ignore`, or the contrary. """ def __init__(self, @@ -33,12 +35,14 @@ class MaxIoUAssigner(BaseAssigner): neg_iou_thr, min_pos_iou=.0, gt_max_assign_all=True, - ignore_iof_thr=-1): + ignore_iof_thr=-1, + ignore_wrt_candidates=True): self.pos_iou_thr = pos_iou_thr self.neg_iou_thr = neg_iou_thr self.min_pos_iou = min_pos_iou self.gt_max_assign_all = gt_max_assign_all self.ignore_iof_thr = ignore_iof_thr + self.ignore_wrt_candidates = ignore_wrt_candidates def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None): """Assign gt to bboxes. @@ -73,13 +77,15 @@ class MaxIoUAssigner(BaseAssigner): if (self.ignore_iof_thr > 0) and (gt_bboxes_ignore is not None) and ( gt_bboxes_ignore.numel() > 0): - ignore_overlaps = bbox_overlaps( - bboxes, gt_bboxes_ignore, mode='iof') - ignore_max_overlaps, _ = ignore_overlaps.max(dim=1) - ignore_bboxes_inds = torch.nonzero( - ignore_max_overlaps > self.ignore_iof_thr).squeeze() - if ignore_bboxes_inds.numel() > 0: - overlaps[ignore_bboxes_inds[:, 0], :] = -1 + if self.ignore_wrt_candidates: + ignore_overlaps = bbox_overlaps( + bboxes, gt_bboxes_ignore, mode='iof') + ignore_max_overlaps, _ = ignore_overlaps.max(dim=1) + else: + ignore_overlaps = bbox_overlaps( + gt_bboxes_ignore, bboxes, mode='iof') + ignore_max_overlaps, _ = ignore_overlaps.max(dim=0) + overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1 assign_result = self.assign_wrt_overlaps(overlaps, gt_labels) return assign_result