From 32df98e970c8848d3d9fd72492aba8bc030377d8 Mon Sep 17 00:00:00 2001
From: Cao Yuhang <yhcao6@gmail.com>
Date: Thu, 20 Jun 2019 22:25:42 +0800
Subject: [PATCH] Add reduction_override flag (#839)

* add reduction_override flag

* change default value of reduction_override to None

* add assertion, fix format

* delete redundant statement in utils

* delete redundant comment
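
For reviewers, a minimal standalone sketch of the new calling convention
(illustration only, not part of the patch; L1LossWithOverride is a
hypothetical module that mirrors the pattern this change adds to
SmoothL1Loss, IoULoss, BalancedL1Loss, etc.):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class L1LossWithOverride(nn.Module):
        """Hypothetical loss module following the reduction_override pattern."""

        def __init__(self, reduction='mean', loss_weight=1.0):
            super().__init__()
            self.reduction = reduction
            self.loss_weight = loss_weight

        def forward(self, pred, target, reduction_override=None):
            # None falls back to the reduction chosen at construction time
            assert reduction_override in (None, 'none', 'mean', 'sum')
            reduction = (
                reduction_override if reduction_override else self.reduction)
            return self.loss_weight * F.l1_loss(
                pred, target, reduction=reduction)

    loss_fn = L1LossWithOverride(reduction='mean')
    pred, target = torch.rand(4, 4), torch.rand(4, 4)
    loss_fn(pred, target)                             # scalar (mean)
    loss_fn(pred, target, reduction_override='none')  # per-element, as OHEM needs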
---
 mmdet/core/bbox/samplers/ohem_sampler.py  |  2 +-
 mmdet/models/bbox_heads/bbox_head.py      | 11 ++++++++---
 mmdet/models/losses/balanced_l1_loss.py   | 13 +++++++++++--
 mmdet/models/losses/cross_entropy_loss.py | 12 ++++++++++--
 mmdet/models/losses/focal_loss.py         | 12 ++++++++++--
 mmdet/models/losses/ghm_loss.py           |  3 +++
 mmdet/models/losses/iou_loss.py           | 13 +++++++++++--
 mmdet/models/losses/smooth_l1_loss.py     | 13 +++++++++++--
 mmdet/models/losses/utils.py              | 11 ++++++-----
 9 files changed, 71 insertions(+), 19 deletions(-)
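
Reviewer note: weight_reduce_loss previously raised whenever avg_factor was
combined with any reduction other than 'mean'; after this patch, avg_factor
is silently ignored for reduction='none', which lets callers such as
BBoxHead.loss pass reduction_override='none' without dropping their
avg_factor argument. A sketch of the resulting semantics (weight handling
omitted for brevity):

    import torch

    def weight_reduce_loss_sketch(loss, reduction='mean', avg_factor=None):
        # without avg_factor, just apply the requested reduction
        if avg_factor is None:
            if reduction == 'mean':
                return loss.mean()
            if reduction == 'sum':
                return loss.sum()
            return loss  # 'none'
        # with avg_factor, only 'mean' (normalize by avg_factor) and
        # 'none' (pass through unreduced) are legal
        if reduction == 'mean':
            return loss.sum() / avg_factor
        if reduction == 'none':
            return loss
        raise ValueError('avg_factor cannot be used with reduction="sum"')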

diff --git a/mmdet/core/bbox/samplers/ohem_sampler.py b/mmdet/core/bbox/samplers/ohem_sampler.py
index 800a1c2..0711d97 100644
--- a/mmdet/core/bbox/samplers/ohem_sampler.py
+++ b/mmdet/core/bbox/samplers/ohem_sampler.py
@@ -36,7 +36,7 @@ class OHEMSampler(BaseSampler):
                 label_weights=cls_score.new_ones(cls_score.size(0)),
                 bbox_targets=None,
                 bbox_weights=None,
-                reduce=False)['loss_cls']
+                reduction_override='none')['loss_cls']
             _, topk_loss_inds = loss.topk(num_expected)
         return inds[topk_loss_inds]
 
diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py
index c67ea8a..436592c 100644
--- a/mmdet/models/bbox_heads/bbox_head.py
+++ b/mmdet/models/bbox_heads/bbox_head.py
@@ -97,12 +97,16 @@ class BBoxHead(nn.Module):
              label_weights,
              bbox_targets,
              bbox_weights,
-             reduce=True):
+             reduction_override=None):
         losses = dict()
         if cls_score is not None:
             avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
             losses['loss_cls'] = self.loss_cls(
-                cls_score, labels, label_weights, avg_factor=avg_factor)
+                cls_score,
+                labels,
+                label_weights,
+                avg_factor=avg_factor,
+                reduction_override=reduction_override)
             losses['acc'] = accuracy(cls_score, labels)
         if bbox_pred is not None:
             pos_inds = labels > 0
@@ -115,7 +119,8 @@ class BBoxHead(nn.Module):
                 pos_bbox_pred,
                 bbox_targets[pos_inds],
                 bbox_weights[pos_inds],
-                avg_factor=bbox_targets.size(0))
+                avg_factor=bbox_targets.size(0),
+                reduction_override=reduction_override)
         return losses
 
     def get_det_bboxes(self,
diff --git a/mmdet/models/losses/balanced_l1_loss.py b/mmdet/models/losses/balanced_l1_loss.py
index 2dee674..8593396 100644
--- a/mmdet/models/losses/balanced_l1_loss.py
+++ b/mmdet/models/losses/balanced_l1_loss.py
@@ -46,7 +46,16 @@ class BalancedL1Loss(nn.Module):
         self.reduction = reduction
         self.loss_weight = loss_weight
 
-    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss_bbox = self.loss_weight * balanced_l1_loss(
             pred,
             target,
@@ -54,7 +63,7 @@ class BalancedL1Loss(nn.Module):
             alpha=self.alpha,
             gamma=self.gamma,
             beta=self.beta,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss_bbox
diff --git a/mmdet/models/losses/cross_entropy_loss.py b/mmdet/models/losses/cross_entropy_loss.py
index 1921978..2f2ce69 100644
--- a/mmdet/models/losses/cross_entropy_loss.py
+++ b/mmdet/models/losses/cross_entropy_loss.py
@@ -73,13 +73,21 @@ class CrossEntropyLoss(nn.Module):
         else:
             self.cls_criterion = cross_entropy
 
-    def forward(self, cls_score, label, weight=None, avg_factor=None,
+    def forward(self,
+                cls_score,
+                label,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
                 **kwargs):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss_cls = self.loss_weight * self.cls_criterion(
             cls_score,
             label,
             weight,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss_cls
diff --git a/mmdet/models/losses/focal_loss.py b/mmdet/models/losses/focal_loss.py
index b8ccfa0..7a46356 100644
--- a/mmdet/models/losses/focal_loss.py
+++ b/mmdet/models/losses/focal_loss.py
@@ -59,7 +59,15 @@ class FocalLoss(nn.Module):
         self.reduction = reduction
         self.loss_weight = loss_weight
 
-    def forward(self, pred, target, weight=None, avg_factor=None):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         if self.use_sigmoid:
             loss_cls = self.loss_weight * sigmoid_focal_loss(
                 pred,
@@ -67,7 +75,7 @@ class FocalLoss(nn.Module):
                 weight,
                 gamma=self.gamma,
                 alpha=self.alpha,
-                reduction=self.reduction,
+                reduction=reduction,
                 avg_factor=avg_factor)
         else:
             raise NotImplementedError
diff --git a/mmdet/models/losses/ghm_loss.py b/mmdet/models/losses/ghm_loss.py
index 7beeb47..95656a2 100644
--- a/mmdet/models/losses/ghm_loss.py
+++ b/mmdet/models/losses/ghm_loss.py
@@ -15,6 +15,7 @@ def _expand_binary_labels(labels, label_weights, label_channels):
     return bin_labels, bin_label_weights
 
 
+# TODO: code refactoring to make it consistent with other losses
 @LOSSES.register_module
 class GHMC(nn.Module):
     """GHM Classification Loss.
@@ -90,6 +91,7 @@ class GHMC(nn.Module):
         return loss * self.loss_weight
 
 
+# TODO: code refactoring to make it consistent with other losses
 @LOSSES.register_module
 class GHMR(nn.Module):
     """GHM Regression Loss.
@@ -116,6 +118,7 @@ class GHMR(nn.Module):
             self.acc_sum = torch.zeros(bins).cuda()
         self.loss_weight = loss_weight
 
+    # TODO: support reduction parameter
     def forward(self, pred, target, label_weight, avg_factor=None):
         """Calculate the GHM-R loss.
 
diff --git a/mmdet/models/losses/iou_loss.py b/mmdet/models/losses/iou_loss.py
index 7c235cd..967e576 100644
--- a/mmdet/models/losses/iou_loss.py
+++ b/mmdet/models/losses/iou_loss.py
@@ -78,15 +78,24 @@ class IoULoss(nn.Module):
         self.reduction = reduction
         self.loss_weight = loss_weight
 
-    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
         if weight is not None and not torch.any(weight > 0):
             return (pred * weight).sum()  # 0
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss = self.loss_weight * iou_loss(
             pred,
             target,
             weight,
             eps=self.eps,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss
diff --git a/mmdet/models/losses/smooth_l1_loss.py b/mmdet/models/losses/smooth_l1_loss.py
index 6a098fc..75d71e8 100644
--- a/mmdet/models/losses/smooth_l1_loss.py
+++ b/mmdet/models/losses/smooth_l1_loss.py
@@ -24,13 +24,22 @@ class SmoothL1Loss(nn.Module):
         self.reduction = reduction
         self.loss_weight = loss_weight
 
-    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss_bbox = self.loss_weight * smooth_l1_loss(
             pred,
             target,
             weight,
             beta=self.beta,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss_bbox
diff --git a/mmdet/models/losses/utils.py b/mmdet/models/losses/utils.py
index b902c64..5c16e06 100644
--- a/mmdet/models/losses/utils.py
+++ b/mmdet/models/losses/utils.py
@@ -42,12 +42,13 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
     # if avg_factor is not specified, just reduce the loss
     if avg_factor is None:
         loss = reduce_loss(loss, reduction)
-    # otherwise average the loss by avg_factor
     else:
-        if reduction != 'mean':
-            raise ValueError(
-                'avg_factor can only be used with reduction="mean"')
-        loss = loss.sum() / avg_factor
+        # if reduction is 'mean', then average the loss by avg_factor
+        if reduction == 'mean':
+            loss = loss.sum() / avg_factor
+        # if reduction is 'none', then do nothing, otherwise raise an error
+        elif reduction != 'none':
+            raise ValueError('avg_factor cannot be used with reduction="sum"')
     return loss
 
 
-- 
GitLab