Skip to content
Snippets Groups Projects
Commit 801c8b19 authored by Kai Chen's avatar Kai Chen
Browse files

bug fix and support different iof computation

parent d1cf5e59
No related branches found
No related tags found
No related merge requests found
...@@ -26,6 +26,8 @@ class MaxIoUAssigner(BaseAssigner): ...@@ -26,6 +26,8 @@ class MaxIoUAssigner(BaseAssigner):
ignore_iof_thr (float): IoF threshold for ignoring bboxes (if ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
`gt_bboxes_ignore` is specified). Negative values mean not `gt_bboxes_ignore` is specified). Negative values mean not
ignoring any bboxes. ignoring any bboxes.
ignore_wrt_candidates (bool): Whether to compute the iof between
`bboxes` and `gt_bboxes_ignore`, or the contrary.
""" """
def __init__(self, def __init__(self,
...@@ -33,12 +35,14 @@ class MaxIoUAssigner(BaseAssigner): ...@@ -33,12 +35,14 @@ class MaxIoUAssigner(BaseAssigner):
neg_iou_thr, neg_iou_thr,
min_pos_iou=.0, min_pos_iou=.0,
gt_max_assign_all=True, gt_max_assign_all=True,
ignore_iof_thr=-1): ignore_iof_thr=-1,
ignore_wrt_candidates=True):
self.pos_iou_thr = pos_iou_thr self.pos_iou_thr = pos_iou_thr
self.neg_iou_thr = neg_iou_thr self.neg_iou_thr = neg_iou_thr
self.min_pos_iou = min_pos_iou self.min_pos_iou = min_pos_iou
self.gt_max_assign_all = gt_max_assign_all self.gt_max_assign_all = gt_max_assign_all
self.ignore_iof_thr = ignore_iof_thr self.ignore_iof_thr = ignore_iof_thr
self.ignore_wrt_candidates = ignore_wrt_candidates
def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None): def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
"""Assign gt to bboxes. """Assign gt to bboxes.
...@@ -73,13 +77,15 @@ class MaxIoUAssigner(BaseAssigner): ...@@ -73,13 +77,15 @@ class MaxIoUAssigner(BaseAssigner):
if (self.ignore_iof_thr > 0) and (gt_bboxes_ignore is not None) and ( if (self.ignore_iof_thr > 0) and (gt_bboxes_ignore is not None) and (
gt_bboxes_ignore.numel() > 0): gt_bboxes_ignore.numel() > 0):
ignore_overlaps = bbox_overlaps( if self.ignore_wrt_candidates:
bboxes, gt_bboxes_ignore, mode='iof') ignore_overlaps = bbox_overlaps(
ignore_max_overlaps, _ = ignore_overlaps.max(dim=1) bboxes, gt_bboxes_ignore, mode='iof')
ignore_bboxes_inds = torch.nonzero( ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
ignore_max_overlaps > self.ignore_iof_thr).squeeze() else:
if ignore_bboxes_inds.numel() > 0: ignore_overlaps = bbox_overlaps(
overlaps[ignore_bboxes_inds[:, 0], :] = -1 gt_bboxes_ignore, bboxes, mode='iof')
ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
assign_result = self.assign_wrt_overlaps(overlaps, gt_labels) assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
return assign_result return assign_result
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment