Commit 32df98e9 authored by Cao Yuhang, committed by Kai Chen

Add reduction_override flag (#839)

* add reduction_override flag

* change default value of reduction_override to None

* add assertion, fix format

* delete redundant statement in util

* delete redundant comment
parent fc0172b4
@@ -36,7 +36,7 @@ class OHEMSampler(BaseSampler):
             label_weights=cls_score.new_ones(cls_score.size(0)),
             bbox_targets=None,
             bbox_weights=None,
-            reduce=False)['loss_cls']
+            reduction_override='none')['loss_cls']
         _, topk_loss_inds = loss.topk(num_expected)
         return inds[topk_loss_inds]
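The sampler needs one loss value per candidate so it can rank them; the old boolean reduce=False is replaced by the more general reduction_override='none'. A minimal standalone sketch of the idea, using torch.nn.functional directly rather than the head's wrapped loss:

import torch
import torch.nn.functional as F

cls_score = torch.randn(8, 4)       # 8 candidate RoIs, 4 classes
labels = torch.randint(0, 4, (8,))
num_expected = 3

# reduction='none' keeps one loss per RoI instead of averaging to a scalar,
# which is what lets OHEM pick the hardest examples.
loss = F.cross_entropy(cls_score, labels, reduction='none')  # shape (8,)
_, topk_loss_inds = loss.topk(num_expected)  # indices of the 3 hardest RoIs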
@@ -97,12 +97,16 @@ class BBoxHead(nn.Module):
              label_weights,
              bbox_targets,
              bbox_weights,
-             reduce=True):
+             reduction_override=None):
         losses = dict()
         if cls_score is not None:
             avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
             losses['loss_cls'] = self.loss_cls(
-                cls_score, labels, label_weights, avg_factor=avg_factor)
+                cls_score,
+                labels,
+                label_weights,
+                avg_factor=avg_factor,
+                reduction_override=reduction_override)
             losses['acc'] = accuracy(cls_score, labels)
         if bbox_pred is not None:
             pos_inds = labels > 0
@@ -115,7 +119,8 @@ class BBoxHead(nn.Module):
                 pos_bbox_pred,
                 bbox_targets[pos_inds],
                 bbox_weights[pos_inds],
-                avg_factor=bbox_targets.size(0))
+                avg_factor=bbox_targets.size(0),
+                reduction_override=reduction_override)
         return losses

     def get_det_bboxes(self,
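With the new signature, ordinary training calls are unchanged (the override defaults to None, so each loss module falls back to its configured reduction), while callers like OHEMSampler can request unreduced losses per call. Hypothetical call sites, assuming bbox_head and the target tensors already exist:

# Normal training step: configured reduction (typically 'mean') applies.
losses = bbox_head.loss(cls_score, bbox_pred, labels, label_weights,
                        bbox_targets, bbox_weights)

# Hard-example mining: one classification loss value per sample.
per_sample_cls = bbox_head.loss(
    cls_score, bbox_pred, labels, label_weights, bbox_targets, bbox_weights,
    reduction_override='none')['loss_cls']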
@@ -46,7 +46,16 @@ class BalancedL1Loss(nn.Module):
         self.reduction = reduction
         self.loss_weight = loss_weight

-    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss_bbox = self.loss_weight * balanced_l1_loss(
             pred,
             target,
@@ -54,7 +63,7 @@ class BalancedL1Loss(nn.Module):
             alpha=self.alpha,
             gamma=self.gamma,
             beta=self.beta,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss_bbox
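Each migrated loss repeats the same two steps: validate the override, then let it take precedence over the constructor default. The pattern, distilled into a self-contained sketch (an illustrative L1 loss, not an mmdetection class):

import torch
import torch.nn as nn

class L1LossSketch(nn.Module):
    def __init__(self, reduction='mean', loss_weight=1.0):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, reduction_override=None):
        # None means "use the configured default"; anything else must be a
        # valid reduction mode.
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = (pred - target).abs()
        if reduction == 'mean':
            loss = loss.mean()
        elif reduction == 'sum':
            loss = loss.sum()
        return self.loss_weight * loss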
@@ -73,13 +73,21 @@ class CrossEntropyLoss(nn.Module):
         else:
             self.cls_criterion = cross_entropy

-    def forward(self, cls_score, label, weight=None, avg_factor=None,
-                **kwargs):
+    def forward(self,
+                cls_score,
+                label,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss_cls = self.loss_weight * self.cls_criterion(
             cls_score,
             label,
             weight,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss_cls
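The observable effect on CrossEntropyLoss, sketched as shape checks (assuming the class is importable from mmdet.models.losses as usual in this codebase, and the default use_sigmoid=False path shown above):

import torch
from mmdet.models.losses import CrossEntropyLoss

loss_fn = CrossEntropyLoss(reduction='mean')
cls_score = torch.randn(5, 3)
label = torch.randint(0, 3, (5,))

assert loss_fn(cls_score, label).dim() == 0                        # scalar mean
assert loss_fn(cls_score, label, reduction_override='none').shape == (5,)
assert loss_fn(cls_score, label, reduction_override='sum').dim() == 0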
@@ -59,7 +59,15 @@ class FocalLoss(nn.Module):
         self.reduction = reduction
         self.loss_weight = loss_weight

-    def forward(self, pred, target, weight=None, avg_factor=None):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         if self.use_sigmoid:
             loss_cls = self.loss_weight * sigmoid_focal_loss(
                 pred,
@@ -67,7 +75,7 @@ class FocalLoss(nn.Module):
                 weight,
                 gamma=self.gamma,
                 alpha=self.alpha,
-                reduction=self.reduction,
+                reduction=reduction,
                 avg_factor=avg_factor)
         else:
             raise NotImplementedError
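Only the sigmoid branch is implemented here. For reference, a pure-PyTorch sigmoid focal loss with the same reduction plumbing could look like the sketch below; the sigmoid_focal_loss used in this repo is a compiled op, so this only approximates its semantics:

import torch.nn.functional as F

def sigmoid_focal_loss_sketch(pred, target, gamma=2.0, alpha=0.25,
                              reduction='mean'):
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)  # BCE-with-logits needs a float target
    # pt is the probability mass on the wrong class, so hard examples
    # (large pt) are up-weighted by pt**gamma.
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target +
                    (1 - alpha) * (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    if reduction == 'mean':
        loss = loss.mean()
    elif reduction == 'sum':
        loss = loss.sum()
    return loss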
@@ -15,6 +15,7 @@ def _expand_binary_labels(labels, label_weights, label_channels):
     return bin_labels, bin_label_weights


+# TODO: code refactoring to make it consistent with other losses
 @LOSSES.register_module
 class GHMC(nn.Module):
     """GHM Classification Loss.
@@ -90,6 +91,7 @@ class GHMC(nn.Module):
         return loss * self.loss_weight


+# TODO: code refactoring to make it consistent with other losses
 @LOSSES.register_module
 class GHMR(nn.Module):
     """GHM Regression Loss.
@@ -116,6 +118,7 @@ class GHMR(nn.Module):
         self.acc_sum = torch.zeros(bins).cuda()
         self.loss_weight = loss_weight

+    # TODO: support reduction parameter
     def forward(self, pred, target, label_weight, avg_factor=None):
         """Calculate the GHM-R loss.
@@ -78,15 +78,24 @@ class IoULoss(nn.Module):
         self.reduction = reduction
         self.loss_weight = loss_weight

-    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
         if weight is not None and not torch.any(weight > 0):
             return (pred * weight).sum()  # 0
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss = self.loss_weight * iou_loss(
             pred,
             target,
             weight,
             eps=self.eps,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss
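Note the early return: when every weight is zero, (pred * weight).sum() is a zero scalar that is still connected to pred in the autograd graph, so backward() works and simply produces zero gradients, whereas a detached constant 0 would break callers that unconditionally call backward(). A quick check of that behavior:

import torch

pred = torch.rand(4, 4, requires_grad=True)
weight = torch.zeros(4, 4)

loss = (pred * weight).sum()   # tensor(0.), but part of the graph
loss.backward()                # valid; gradients are simply all zero
assert torch.all(pred.grad == 0)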
@@ -24,13 +24,22 @@ class SmoothL1Loss(nn.Module):
         self.reduction = reduction
         self.loss_weight = loss_weight

-    def forward(self, pred, target, weight=None, avg_factor=None, **kwargs):
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         loss_bbox = self.loss_weight * smooth_l1_loss(
             pred,
             target,
             weight,
             beta=self.beta,
-            reduction=self.reduction,
+            reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
         return loss_bbox
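For context, the elementwise term that smooth_l1_loss reduces follows the standard definition: quadratic for |pred - target| < beta, linear beyond it. A sketch consistent with that definition:

import torch

def smooth_l1_sketch(pred, target, beta=1.0):
    diff = (pred - target).abs()
    # 0.5 * d**2 / beta below beta, d - 0.5 * beta above; the two branches
    # meet smoothly at d == beta.
    return torch.where(diff < beta,
                       0.5 * diff * diff / beta,
                       diff - 0.5 * beta)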
@@ -42,12 +42,13 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
     # if avg_factor is not specified, just reduce the loss
     if avg_factor is None:
         loss = reduce_loss(loss, reduction)
-    # otherwise average the loss by avg_factor
     else:
-        if reduction != 'mean':
-            raise ValueError(
-                'avg_factor can only be used with reduction="mean"')
-        loss = loss.sum() / avg_factor
+        # if reduction is mean, then average the loss by avg_factor
+        if reduction == 'mean':
+            loss = loss.sum() / avg_factor
+        # if reduction is 'none', then do nothing, otherwise raise an error
+        elif reduction != 'none':
+            raise ValueError('avg_factor can not be used with reduction="sum"')
     return loss
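After this change, avg_factor combines with each reduction mode as follows: with 'mean' the loss is summed and divided by avg_factor, with 'none' the unreduced loss passes through untouched (this is what lets reduction_override='none' coexist with avg_factor-passing callers like BBoxHead.loss), and with 'sum' it is still rejected. Sketched as assertions, assuming weight_reduce_loss is imported from this utils module:

import torch

loss = torch.ones(4)

assert weight_reduce_loss(loss, reduction='mean', avg_factor=8.0) == 0.5
assert weight_reduce_loss(loss, reduction='none', avg_factor=8.0).shape == (4,)
try:
    weight_reduce_loss(loss, reduction='sum', avg_factor=8.0)
except ValueError:
    pass  # avg_factor cannot be combined with reduction='sum'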