Skip to content
Snippets Groups Projects
Commit 8cf2de8a authored by Cao Yuhang's avatar Cao Yuhang Committed by Kai Chen
Browse files

Fix oom when there are too many gts (#1575)

* fix oom when there are too many gts

* add gpu_assign_thr

* add gpu assign thr to approx max iou assigner

* upgrade yapf format
parent f1c79392
No related branches found
No related tags found
No related merge requests found
......@@ -27,6 +27,9 @@ class ApproxMaxIoUAssigner(MaxIoUAssigner):
ignoring any bboxes.
ignore_wrt_candidates (bool): Whether to compute the iof between
`bboxes` and `gt_bboxes_ignore`, or the contrary.
gpu_assign_thr (int): The upper bound of the number of GT for GPU
assign. When the number of gt is above this threshold, will assign
on CPU device. Negative values mean not assign on CPU.
"""
def __init__(self,
......@@ -35,13 +38,15 @@ class ApproxMaxIoUAssigner(MaxIoUAssigner):
min_pos_iou=.0,
gt_max_assign_all=True,
ignore_iof_thr=-1,
ignore_wrt_candidates=True):
ignore_wrt_candidates=True,
gpu_assign_thr=-1):
self.pos_iou_thr = pos_iou_thr
self.neg_iou_thr = neg_iou_thr
self.min_pos_iou = min_pos_iou
self.gt_max_assign_all = gt_max_assign_all
self.ignore_iof_thr = ignore_iof_thr
self.ignore_wrt_candidates = ignore_wrt_candidates
self.gpu_assign_thr = gpu_assign_thr
def assign(self,
approxs,
......@@ -90,6 +95,17 @@ class ApproxMaxIoUAssigner(MaxIoUAssigner):
approxs = torch.transpose(
approxs.view(num_squares, approxs_per_octave, 4), 0,
1).contiguous().view(-1, 4)
assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
num_gts > self.gpu_assign_thr) else False
# compute overlap and assign gt on CPU when number of GT is large
if assign_on_cpu:
device = approxs.device
approxs = approxs.cpu()
gt_bboxes = gt_bboxes.cpu()
if gt_bboxes_ignore is not None:
gt_bboxes_ignore = gt_bboxes_ignore.cpu()
if gt_labels is not None:
gt_labels = gt_labels.cpu()
all_overlaps = bbox_overlaps(approxs, gt_bboxes)
overlaps, _ = all_overlaps.view(approxs_per_octave, num_squares,
......@@ -111,4 +127,9 @@ class ApproxMaxIoUAssigner(MaxIoUAssigner):
overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
if assign_on_cpu:
assign_result.gt_inds = assign_result.gt_inds.to(device)
assign_result.max_overlaps = assign_result.max_overlaps.to(device)
if assign_result.labels is not None:
assign_result.labels = assign_result.labels.to(device)
return assign_result
......@@ -28,6 +28,9 @@ class MaxIoUAssigner(BaseAssigner):
ignoring any bboxes.
ignore_wrt_candidates (bool): Whether to compute the iof between
`bboxes` and `gt_bboxes_ignore`, or the contrary.
gpu_assign_thr (int): The upper bound of the number of GT for GPU
assign. When the number of gt is above this threshold, will assign
on CPU device. Negative values mean not assign on CPU.
"""
def __init__(self,
......@@ -36,13 +39,15 @@ class MaxIoUAssigner(BaseAssigner):
min_pos_iou=.0,
gt_max_assign_all=True,
ignore_iof_thr=-1,
ignore_wrt_candidates=True):
ignore_wrt_candidates=True,
gpu_assign_thr=-1):
self.pos_iou_thr = pos_iou_thr
self.neg_iou_thr = neg_iou_thr
self.min_pos_iou = min_pos_iou
self.gt_max_assign_all = gt_max_assign_all
self.ignore_iof_thr = ignore_iof_thr
self.ignore_wrt_candidates = ignore_wrt_candidates
self.gpu_assign_thr = gpu_assign_thr
def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
"""Assign gt to bboxes.
......@@ -72,6 +77,17 @@ class MaxIoUAssigner(BaseAssigner):
"""
if bboxes.shape[0] == 0 or gt_bboxes.shape[0] == 0:
raise ValueError('No gt or bboxes')
assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
gt_bboxes.shape[0] > self.gpu_assign_thr) else False
# compute overlap and assign gt on CPU when number of GT is large
if assign_on_cpu:
device = bboxes.device
bboxes = bboxes.cpu()
gt_bboxes = gt_bboxes.cpu()
if gt_bboxes_ignore is not None:
gt_bboxes_ignore = gt_bboxes_ignore.cpu()
if gt_labels is not None:
gt_labels = gt_labels.cpu()
bboxes = bboxes[:, :4]
overlaps = bbox_overlaps(gt_bboxes, bboxes)
......@@ -88,6 +104,11 @@ class MaxIoUAssigner(BaseAssigner):
overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
if assign_on_cpu:
assign_result.gt_inds = assign_result.gt_inds.to(device)
assign_result.max_overlaps = assign_result.max_overlaps.to(device)
if assign_result.labels is not None:
assign_result.labels = assign_result.labels.to(device)
return assign_result
def assign_wrt_overlaps(self, overlaps, gt_labels=None):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment