diff --git a/mmdet/models/losses/cross_entropy_loss.py b/mmdet/models/losses/cross_entropy_loss.py
index 2f2ce69c4c8e8d77720531b0e8a149aec1e1fde6..b7bab125f1b0dc8b0196e31b97116bdb134b135a 100644
--- a/mmdet/models/losses/cross_entropy_loss.py
+++ b/mmdet/models/losses/cross_entropy_loss.py
@@ -29,14 +29,13 @@ def binary_cross_entropy(pred,
     if pred.dim() != label.dim():
         label, weight = _expand_binary_labels(label, weight, pred.size(-1))
 
-    # element-wise losses
+    # weighted element-wise losses
     if weight is not None:
         weight = weight.float()
     loss = F.binary_cross_entropy_with_logits(
         pred, label.float(), weight, reduction='none')
-    # apply weights and do the reduction
-    loss = weight_reduce_loss(
-        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
+    # do the reduction for the weighted loss
+    loss = weight_reduce_loss(loss, reduction=reduction, avg_factor=avg_factor)
 
     return loss
 
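Below is a standalone sketch (not part of the patch) of why `weight=weight` is dropped from the `weight_reduce_loss` call: the unchanged `F.binary_cross_entropy_with_logits(pred, label.float(), weight, ...)` call already multiplies the element-wise loss by `weight`, so passing the weight to `weight_reduce_loss` as well would scale the loss by it a second time. The tensor shapes and the plain mean reduction (no `avg_factor`) are illustrative assumptions.

# Illustrative only; `pred`, `label`, `weight` mirror the function arguments.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pred = torch.randn(4, 3)
label = torch.randint(0, 2, (4, 3)).float()
weight = torch.rand(4, 3)

# Weighted element-wise loss, as produced by the unchanged BCE call in the hunk:
# `weight` is already multiplied in by binary_cross_entropy_with_logits.
weighted = F.binary_cross_entropy_with_logits(pred, label, weight, reduction='none')

# New behaviour: only reduce the already-weighted loss (here: a plain mean).
loss_new = weighted.mean()

# Old behaviour: weight_reduce_loss multiplied by `weight` a second time.
loss_old = (weighted * weight).mean()

print(loss_new.item(), loss_old.item())  # loss_old applies the weight twice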