Skip to content
Snippets Groups Projects
Commit e4917130 authored by Cao Yuhang's avatar Cao Yuhang Committed by Kai Chen
Browse files

add reduction_override to BoundedIoULoss (#850)

parent 4a0d7add
No related branches found
No related tags found
No related merge requests found
......@@ -111,16 +111,25 @@ class BoundedIoULoss(nn.Module):
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None,
**kwargs):
if weight is not None and not torch.any(weight > 0):
return (pred * weight).sum() # 0
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss = self.loss_weight * bounded_iou_loss(
pred,
target,
weight,
beta=self.beta,
eps=self.eps,
reduction=self.reduction,
reduction=reduction,
avg_factor=avg_factor,
**kwargs)
return loss
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment