From e49171303033e14dbde068dbf39046afe357629b Mon Sep 17 00:00:00 2001
From: Cao Yuhang <yhcao6@gmail.com>
Date: Sat, 22 Jun 2019 14:59:28 +0800
Subject: [PATCH] add reduction_override to BoundedIoULoss (#850)

---
 mmdet/models/losses/iou_loss.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/mmdet/models/losses/iou_loss.py b/mmdet/models/losses/iou_loss.py
index 967e576..011ff36 100644
--- a/mmdet/models/losses/iou_loss.py
+++ b/mmdet/models/losses/iou_loss.py
@@ -111,16 +111,25 @@ class BoundedIoULoss(nn.Module):
         self.reduction = reduction
         self.loss_weight = loss_weight
 
-    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
         if weight is not None and not torch.any(weight > 0):
             return (pred * weight).sum()  # 0
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss = self.loss_weight * bounded_iou_loss(
             pred,
             target,
             weight,
             beta=self.beta,
             eps=self.eps,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss
-- 
GitLab