From e49171303033e14dbde068dbf39046afe357629b Mon Sep 17 00:00:00 2001 From: Cao Yuhang <yhcao6@gmail.com> Date: Sat, 22 Jun 2019 14:59:28 +0800 Subject: [PATCH] add reduction_override to BoundedIoULoss (#850) --- mmdet/models/losses/iou_loss.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/mmdet/models/losses/iou_loss.py b/mmdet/models/losses/iou_loss.py index 967e576..011ff36 100644 --- a/mmdet/models/losses/iou_loss.py +++ b/mmdet/models/losses/iou_loss.py @@ -111,16 +111,25 @@ class BoundedIoULoss(nn.Module): self.reduction = reduction self.loss_weight = loss_weight - def forward(self, pred, target, weight=None, avg_factor=None, **kwargs): + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): if weight is not None and not torch.any(weight > 0): return (pred * weight).sum() # 0 + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) loss = self.loss_weight * bounded_iou_loss( pred, target, weight, beta=self.beta, eps=self.eps, - reduction=self.reduction, + reduction=reduction, avg_factor=avg_factor, **kwargs) return loss -- GitLab