diff --git a/mmdet/models/losses/iou_loss.py b/mmdet/models/losses/iou_loss.py
index 967e5764fdc1f7ed43bc17576a79947fe7fcfe50..011ff3620fe9562f57124970b323386493221319 100644
--- a/mmdet/models/losses/iou_loss.py
+++ b/mmdet/models/losses/iou_loss.py
@@ -111,16 +111,25 @@ class BoundedIoULoss(nn.Module):
         self.reduction = reduction
         self.loss_weight = loss_weight
 
-    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
         if weight is not None and not torch.any(weight > 0):
             return (pred * weight).sum()  # 0
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss = self.loss_weight * bounded_iou_loss(
             pred,
             target,
             weight,
             beta=self.beta,
             eps=self.eps,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss