Commit 4a0d7add authored by Cao Yuhang, committed by Kai Chen

Fix bug of ce loss when reduction != mean (#848)

* fix bug of ce loss when reduction != mean

* change function order

* modify comment

* minor fix
parent f724f9ac
@@ -2,10 +2,21 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .utils import weight_reduce_loss, weighted_loss
+from .utils import weight_reduce_loss
 from ..registry import LOSSES
 
-cross_entropy = weighted_loss(F.cross_entropy)
+
+def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
+    # element-wise losses
+    loss = F.cross_entropy(pred, label, reduction='none')
+
+    # apply weights and do the reduction
+    if weight is not None:
+        weight = weight.float()
+    loss = weight_reduce_loss(
+        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
+
+    return loss
 
 
 def _expand_binary_labels(labels, label_weights, label_channels):
...
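Before this commit, cross_entropy was built by wrapping F.cross_entropy with the weighted_loss decorator. Because F.cross_entropy averages over elements by default, the wrapper's own weighting and reduction logic then ran on an already-reduced scalar, so reduction='sum', reduction='none', or a custom avg_factor gave wrong results; that reading is inferred from the diff above. The sketch below shows the fixed control flow in isolation: compute per-sample losses with reduction='none', then weight and reduce exactly once. The weight_reduce_loss here is a simplified stand-in, an assumption for illustration, not mmdetection's actual helper.

# Minimal, self-contained sketch of the fixed loss path.
# NOTE: this weight_reduce_loss is a simplified assumption; the real one
# lives in mmdet's losses utilities and handles more cases.
import torch
import torch.nn.functional as F


def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    # apply element-wise weights before any reduction
    if weight is not None:
        loss = loss * weight
    if avg_factor is None:
        if reduction == 'mean':
            return loss.mean()
        if reduction == 'sum':
            return loss.sum()
        return loss  # reduction == 'none'
    # with avg_factor, averaging divides by the given factor instead
    if reduction == 'mean':
        return loss.sum() / avg_factor
    if reduction == 'none':
        return loss
    raise ValueError('avg_factor cannot be used with reduction="sum"')


def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
    # compute per-sample losses first, so that weighting and the
    # requested reduction are applied exactly once
    loss = F.cross_entropy(pred, label, reduction='none')
    if weight is not None:
        weight = weight.float()
    return weight_reduce_loss(loss, weight, reduction, avg_factor)


pred = torch.randn(4, 3)            # 4 samples, 3 classes
label = torch.tensor([0, 2, 1, 2])
print(cross_entropy(pred, label, reduction='none').shape)  # torch.Size([4])
print(cross_entropy(pred, label, reduction='sum').shape)   # torch.Size([])

With the old decorator-based definition, the 'none' call above would have returned a scalar rather than one loss per sample, which is exactly the behavior this commit corrects.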