diff --git a/mmdet/models/losses/iou_loss.py b/mmdet/models/losses/iou_loss.py index 967e5764fdc1f7ed43bc17576a79947fe7fcfe50..011ff3620fe9562f57124970b323386493221319 100644 --- a/mmdet/models/losses/iou_loss.py +++ b/mmdet/models/losses/iou_loss.py @@ -111,16 +111,25 @@ class BoundedIoULoss(nn.Module): self.reduction = reduction self.loss_weight = loss_weight - def forward(self, pred, target, weight=None, avg_factor=None, **kwargs): + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): if weight is not None and not torch.any(weight > 0): return (pred * weight).sum() # 0 + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) loss = self.loss_weight * bounded_iou_loss( pred, target, weight, beta=self.beta, eps=self.eps, - reduction=self.reduction, + reduction=reduction, avg_factor=avg_factor, **kwargs) return loss