From 4a0d7add98ea010aa1fd4a70981987b1d501656e Mon Sep 17 00:00:00 2001
From: Cao Yuhang <yhcao6@gmail.com>
Date: Fri, 21 Jun 2019 22:57:23 +0800
Subject: [PATCH] Fix bug of ce loss when reduction != mean (#848)

* fix bug of ce loss when reduction != mean

* change function order

* modify comment

* minor fix
---
 mmdet/models/losses/cross_entropy_loss.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/mmdet/models/losses/cross_entropy_loss.py b/mmdet/models/losses/cross_entropy_loss.py
index b7bab12..fe10b86 100644
--- a/mmdet/models/losses/cross_entropy_loss.py
+++ b/mmdet/models/losses/cross_entropy_loss.py
@@ -2,10 +2,21 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .utils import weight_reduce_loss, weighted_loss
+from .utils import weight_reduce_loss
 from ..registry import LOSSES
 
-cross_entropy = weighted_loss(F.cross_entropy)
+
+def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
+    # element-wise losses
+    loss = F.cross_entropy(pred, label, reduction='none')
+
+    # apply weights and do the reduction
+    if weight is not None:
+        weight = weight.float()
+    loss = weight_reduce_loss(
+        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
+
+    return loss
 
 
 def _expand_binary_labels(labels, label_weights, label_channels):
--
GitLab
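
A minimal usage sketch of the patched function (illustrative, not part of the patch). The import path follows the file touched above; the tensor shapes and values are assumptions. With the fix, F.cross_entropy is always invoked with reduction='none', so the requested reduction is applied exactly once, by weight_reduce_loss.

    import torch
    from mmdet.models.losses.cross_entropy_loss import cross_entropy

    pred = torch.randn(4, 10)           # logits for 4 samples, 10 classes (assumed shapes)
    label = torch.randint(0, 10, (4,))  # ground-truth class indices

    per_sample = cross_entropy(pred, label, reduction='none')  # tensor of shape (4,)
    summed = cross_entropy(pred, label, reduction='sum')       # scalar; non-mean reductions were buggy before
    averaged = cross_entropy(pred, label, reduction='mean')    # scalar (default)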