diff --git a/mmdet/models/losses/cross_entropy_loss.py b/mmdet/models/losses/cross_entropy_loss.py
index b7bab125f1b0dc8b0196e31b97116bdb134b135a..fe10b866456ed8bc88910733ba0b14d18d817af6 100644
--- a/mmdet/models/losses/cross_entropy_loss.py
+++ b/mmdet/models/losses/cross_entropy_loss.py
@@ -2,10 +2,21 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .utils import weight_reduce_loss, weighted_loss
+from .utils import weight_reduce_loss
 from ..registry import LOSSES
 
-cross_entropy = weighted_loss(F.cross_entropy)
+
+def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
+    # compute the element-wise loss with no reduction
+    loss = F.cross_entropy(pred, label, reduction='none')
+
+    # apply element-wise weights, then reduce the loss
+    if weight is not None:
+        weight = weight.float()
+    loss = weight_reduce_loss(
+        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
+
+    return loss
 
 
 def _expand_binary_labels(labels, label_weights, label_channels):
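
For reference, below is a self-contained sketch of how the new cross_entropy wrapper behaves. The weight_reduce_loss stand-in is only an assumption about the helper imported from .utils (scale the element-wise loss by weight, then reduce, dividing by avg_factor instead of the element count when one is given); it is not the library's actual implementation, and the tensors at the bottom are purely illustrative.

    import torch
    import torch.nn.functional as F


    def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
        # Assumed behaviour of the .utils helper: weight first, then reduce.
        if weight is not None:
            loss = loss * weight
        if avg_factor is None:
            if reduction == 'mean':
                return loss.mean()
            if reduction == 'sum':
                return loss.sum()
            return loss
        if reduction == 'mean':
            return loss.sum() / avg_factor
        if reduction == 'none':
            return loss
        raise ValueError('avg_factor can only be used with reduction="mean"')


    def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
        # Same logic as the patched function above.
        loss = F.cross_entropy(pred, label, reduction='none')
        if weight is not None:
            weight = weight.float()
        return weight_reduce_loss(
            loss, weight=weight, reduction=reduction, avg_factor=avg_factor)


    pred = torch.randn(4, 10)                # 4 samples, 10 classes
    label = torch.randint(0, 10, (4,))
    weight = torch.tensor([1., 1., 0., 0.])  # mask out the last two samples

    # With no weight or avg_factor the wrapper matches plain F.cross_entropy.
    assert torch.allclose(cross_entropy(pred, label),
                          F.cross_entropy(pred, label))

    # Weighted call: only the first two samples contribute, and the sum is
    # averaged over avg_factor rather than the batch size.
    print(cross_entropy(pred, label, weight=weight, avg_factor=weight.sum()))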