From 801c8b19931fb40774eda6dbb6917b6d1085ce8a Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Wed, 13 Feb 2019 16:22:54 +0800
Subject: [PATCH] bug fix and support different iof computation

---
 mmdet/core/bbox/assigners/max_iou_assigner.py | 22 ++++++++++++-------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/mmdet/core/bbox/assigners/max_iou_assigner.py b/mmdet/core/bbox/assigners/max_iou_assigner.py
index dceee31..57a1e75 100644
--- a/mmdet/core/bbox/assigners/max_iou_assigner.py
+++ b/mmdet/core/bbox/assigners/max_iou_assigner.py
@@ -26,6 +26,8 @@ class MaxIoUAssigner(BaseAssigner):
         ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
             `gt_bboxes_ignore` is specified). Negative values mean not
             ignoring any bboxes.
+        ignore_wrt_candidates (bool): Whether to compute the iof between
+            `bboxes` and `gt_bboxes_ignore`, or the contrary.
     """
 
     def __init__(self,
@@ -33,12 +35,14 @@ class MaxIoUAssigner(BaseAssigner):
                  neg_iou_thr,
                  min_pos_iou=.0,
                  gt_max_assign_all=True,
-                 ignore_iof_thr=-1):
+                 ignore_iof_thr=-1,
+                 ignore_wrt_candidates=True):
         self.pos_iou_thr = pos_iou_thr
         self.neg_iou_thr = neg_iou_thr
         self.min_pos_iou = min_pos_iou
         self.gt_max_assign_all = gt_max_assign_all
         self.ignore_iof_thr = ignore_iof_thr
+        self.ignore_wrt_candidates = ignore_wrt_candidates
 
     def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
         """Assign gt to bboxes.
@@ -73,13 +77,15 @@ class MaxIoUAssigner(BaseAssigner):
 
         if (self.ignore_iof_thr > 0) and (gt_bboxes_ignore is not None) and (
                 gt_bboxes_ignore.numel() > 0):
-            ignore_overlaps = bbox_overlaps(
-                bboxes, gt_bboxes_ignore, mode='iof')
-            ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
-            ignore_bboxes_inds = torch.nonzero(
-                ignore_max_overlaps > self.ignore_iof_thr).squeeze()
-            if ignore_bboxes_inds.numel() > 0:
-                overlaps[ignore_bboxes_inds[:, 0], :] = -1
+            if self.ignore_wrt_candidates:
+                ignore_overlaps = bbox_overlaps(
+                    bboxes, gt_bboxes_ignore, mode='iof')
+                ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
+            else:
+                ignore_overlaps = bbox_overlaps(
+                    gt_bboxes_ignore, bboxes, mode='iof')
+                ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
+            overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
 
         assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
         return assign_result
-- 
GitLab