From 8cf2de8a7b3a94a2c69cbe44a9c22921d9bdb45f Mon Sep 17 00:00:00 2001
From: Cao Yuhang <yhcao6@gmail.com>
Date: Thu, 28 Nov 2019 23:23:06 +0800
Subject: [PATCH] Fix oom when there are too many gts (#1575)

* fix oom when there are too many gts

* add gpu_assign_thr

* add gpu assign thr to approx max iou assigner

* upgrade yapf format
---
 .../bbox/assigners/approx_max_iou_assigner.py | 23 ++++++++++++++++++-
 mmdet/core/bbox/assigners/max_iou_assigner.py | 23 ++++++++++++++++++-
 2 files changed, 44 insertions(+), 2 deletions(-)

diff --git a/mmdet/core/bbox/assigners/approx_max_iou_assigner.py b/mmdet/core/bbox/assigners/approx_max_iou_assigner.py
index 4ab5259..a52ed26 100644
--- a/mmdet/core/bbox/assigners/approx_max_iou_assigner.py
+++ b/mmdet/core/bbox/assigners/approx_max_iou_assigner.py
@@ -27,6 +27,9 @@ class ApproxMaxIoUAssigner(MaxIoUAssigner):
             ignoring any bboxes.
         ignore_wrt_candidates (bool): Whether to compute the iof between
             `bboxes` and `gt_bboxes_ignore`, or the contrary.
+        gpu_assign_thr (int): The upper bound of the number of GT for GPU
+            assign. When the number of gt is above this threshold, will assign
+            on CPU device. Negative values mean not assign on CPU.
     """
 
     def __init__(self,
@@ -35,13 +38,15 @@ class ApproxMaxIoUAssigner(MaxIoUAssigner):
                  min_pos_iou=.0,
                  gt_max_assign_all=True,
                  ignore_iof_thr=-1,
-                 ignore_wrt_candidates=True):
+                 ignore_wrt_candidates=True,
+                 gpu_assign_thr=-1):
         self.pos_iou_thr = pos_iou_thr
         self.neg_iou_thr = neg_iou_thr
         self.min_pos_iou = min_pos_iou
         self.gt_max_assign_all = gt_max_assign_all
         self.ignore_iof_thr = ignore_iof_thr
         self.ignore_wrt_candidates = ignore_wrt_candidates
+        self.gpu_assign_thr = gpu_assign_thr
 
     def assign(self,
                approxs,
@@ -90,6 +95,17 @@ class ApproxMaxIoUAssigner(MaxIoUAssigner):
         approxs = torch.transpose(
             approxs.view(num_squares, approxs_per_octave, 4), 0,
             1).contiguous().view(-1, 4)
+        assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
+            num_gts > self.gpu_assign_thr) else False
+        # compute overlap and assign gt on CPU when number of GT is large
+        if assign_on_cpu:
+            device = approxs.device
+            approxs = approxs.cpu()
+            gt_bboxes = gt_bboxes.cpu()
+            if gt_bboxes_ignore is not None:
+                gt_bboxes_ignore = gt_bboxes_ignore.cpu()
+            if gt_labels is not None:
+                gt_labels = gt_labels.cpu()
         all_overlaps = bbox_overlaps(approxs, gt_bboxes)
 
         overlaps, _ = all_overlaps.view(approxs_per_octave, num_squares,
@@ -111,4 +127,9 @@ class ApproxMaxIoUAssigner(MaxIoUAssigner):
             overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
 
         assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
+        if assign_on_cpu:
+            assign_result.gt_inds = assign_result.gt_inds.to(device)
+            assign_result.max_overlaps = assign_result.max_overlaps.to(device)
+            if assign_result.labels is not None:
+                assign_result.labels = assign_result.labels.to(device)
         return assign_result
diff --git a/mmdet/core/bbox/assigners/max_iou_assigner.py b/mmdet/core/bbox/assigners/max_iou_assigner.py
index 5c93b36..3eb8936 100644
--- a/mmdet/core/bbox/assigners/max_iou_assigner.py
+++ b/mmdet/core/bbox/assigners/max_iou_assigner.py
@@ -28,6 +28,9 @@ class MaxIoUAssigner(BaseAssigner):
             ignoring any bboxes.
         ignore_wrt_candidates (bool): Whether to compute the iof between
             `bboxes` and `gt_bboxes_ignore`, or the contrary.
+        gpu_assign_thr (int): The upper bound of the number of GT for GPU
+            assign. When the number of gt is above this threshold, will assign
+            on CPU device. Negative values mean not assign on CPU.
     """
 
     def __init__(self,
@@ -36,13 +39,15 @@ class MaxIoUAssigner(BaseAssigner):
                  min_pos_iou=.0,
                  gt_max_assign_all=True,
                  ignore_iof_thr=-1,
-                 ignore_wrt_candidates=True):
+                 ignore_wrt_candidates=True,
+                 gpu_assign_thr=-1):
         self.pos_iou_thr = pos_iou_thr
         self.neg_iou_thr = neg_iou_thr
         self.min_pos_iou = min_pos_iou
         self.gt_max_assign_all = gt_max_assign_all
         self.ignore_iof_thr = ignore_iof_thr
         self.ignore_wrt_candidates = ignore_wrt_candidates
+        self.gpu_assign_thr = gpu_assign_thr
 
     def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
         """Assign gt to bboxes.
@@ -72,6 +77,17 @@ class MaxIoUAssigner(BaseAssigner):
         """
         if bboxes.shape[0] == 0 or gt_bboxes.shape[0] == 0:
             raise ValueError('No gt or bboxes')
+        assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
+            gt_bboxes.shape[0] > self.gpu_assign_thr) else False
+        # compute overlap and assign gt on CPU when number of GT is large
+        if assign_on_cpu:
+            device = bboxes.device
+            bboxes = bboxes.cpu()
+            gt_bboxes = gt_bboxes.cpu()
+            if gt_bboxes_ignore is not None:
+                gt_bboxes_ignore = gt_bboxes_ignore.cpu()
+            if gt_labels is not None:
+                gt_labels = gt_labels.cpu()
         bboxes = bboxes[:, :4]
         overlaps = bbox_overlaps(gt_bboxes, bboxes)
 
@@ -88,6 +104,11 @@ class MaxIoUAssigner(BaseAssigner):
             overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
 
         assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
+        if assign_on_cpu:
+            assign_result.gt_inds = assign_result.gt_inds.to(device)
+            assign_result.max_overlaps = assign_result.max_overlaps.to(device)
+            if assign_result.labels is not None:
+                assign_result.labels = assign_result.labels.to(device)
         return assign_result
 
     def assign_wrt_overlaps(self, overlaps, gt_labels=None):
-- 
GitLab