diff --git a/mmdet/models/losses/cross_entropy_loss.py b/mmdet/models/losses/cross_entropy_loss.py index b7bab125f1b0dc8b0196e31b97116bdb134b135a..fe10b866456ed8bc88910733ba0b14d18d817af6 100644 --- a/mmdet/models/losses/cross_entropy_loss.py +++ b/mmdet/models/losses/cross_entropy_loss.py @@ -2,10 +2,21 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .utils import weight_reduce_loss, weighted_loss +from .utils import weight_reduce_loss from ..registry import LOSSES -cross_entropy = weighted_loss(F.cross_entropy) + +def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None): + # element-wise losses + loss = F.cross_entropy(pred, label, reduction='none') + + # apply weights and do the reduction + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss def _expand_binary_labels(labels, label_weights, label_channels):