From 01ddb988e506855a5217584aaa1e85266fdf8f29 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Sun, 9 Dec 2018 15:58:26 +0800
Subject: [PATCH] add the gt_max_assign_all argument for MaxIoUAssigner

---
 mmdet/core/bbox/assigners/max_iou_assigner.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/mmdet/core/bbox/assigners/max_iou_assigner.py b/mmdet/core/bbox/assigners/max_iou_assigner.py
index c43db07..2f1b288 100644
--- a/mmdet/core/bbox/assigners/max_iou_assigner.py
+++ b/mmdet/core/bbox/assigners/max_iou_assigner.py
@@ -19,8 +19,10 @@ class MaxIoUAssigner(BaseAssigner):
         pos_iou_thr (float): IoU threshold for positive bboxes.
         neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
         min_pos_iou (float): Minimum iou for a bbox to be considered as a
-            positive bbox. For RPN, it is usually set as 0.3, for Fast R-CNN,
-            it is usually set as pos_iou_thr
+            positive bbox. Positive samples can have smaller IoU than
+            pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
+        gt_max_assign_all (bool): Whether to assign all bboxes with the same
+            highest overlap with some gt to that gt.
         ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
             `gt_bboxes_ignore` is specified). Negative values mean not
             ignoring any bboxes.
@@ -30,10 +32,12 @@ class MaxIoUAssigner(BaseAssigner):
                  pos_iou_thr,
                  neg_iou_thr,
                  min_pos_iou=.0,
+                 gt_max_assign_all=True,
                  ignore_iof_thr=-1):
         self.pos_iou_thr = pos_iou_thr
         self.neg_iou_thr = neg_iou_thr
         self.min_pos_iou = min_pos_iou
+        self.gt_max_assign_all = gt_max_assign_all
         self.ignore_iof_thr = ignore_iof_thr
 
     def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
@@ -124,7 +128,11 @@ class MaxIoUAssigner(BaseAssigner):
         # 4. assign fg: for each gt, proposals with highest IoU
         for i in range(num_gts):
             if gt_max_overlaps[i] >= self.min_pos_iou:
-                assigned_gt_inds[overlaps[:, i] == gt_max_overlaps[i]] = i + 1
+                if self.gt_max_assign_all:
+                    max_iou_inds = overlaps[:, i] == gt_max_overlaps[i]
+                    assigned_gt_inds[max_iou_inds] = i + 1
+                else:
+                    assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
 
         if gt_labels is not None:
             assigned_labels = assigned_gt_inds.new_zeros((num_bboxes, ))
-- 
GitLab