diff --git a/mmdet/core/anchor/anchor_target.py b/mmdet/core/anchor/anchor_target.py
index 264897713556b202ba25705bf6af54f3793ab67e..60c902eaa86575f5ea3632f4ffea1c65e4dc1c0e 100644
--- a/mmdet/core/anchor/anchor_target.py
+++ b/mmdet/core/anchor/anchor_target.py
@@ -159,16 +159,6 @@ def anchor_target_single(flat_anchors,
             neg_inds)
 
 
-def expand_binary_labels(labels, label_weights, label_channels):
-    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
-    inds = torch.nonzero(labels >= 1).squeeze()
-    if inds.numel() > 0:
-        bin_labels[inds, labels[inds] - 1] = 1
-    bin_label_weights = label_weights.view(-1, 1).expand(
-        label_weights.size(0), label_channels)
-    return bin_labels, bin_label_weights
-
-
 def anchor_inside_flags(flat_anchors, valid_flags, img_shape,
                         allowed_border=0):
     img_h, img_w = img_shape[:2]
diff --git a/mmdet/core/loss/losses.py b/mmdet/core/loss/losses.py
index 922e058a6ed2328d29887f30dd0504d03cf35550..1c5bf70051b16d86c3b0027f3277e9167931ca02 100644
--- a/mmdet/core/loss/losses.py
+++ b/mmdet/core/loss/losses.py
@@ -23,6 +23,8 @@ def weighted_cross_entropy(pred, label, weight, avg_factor=None, reduce=True):
 
 
 def weighted_binary_cross_entropy(pred, label, weight, avg_factor=None):
+    if pred.dim() != label.dim():
+        label, weight = _expand_binary_labels(label, weight, pred.size(-1))
     if avg_factor is None:
         avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
     return F.binary_cross_entropy_with_logits(
@@ -115,3 +117,13 @@ def accuracy(pred, target, topk=1):
         correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
         res.append(correct_k.mul_(100.0 / pred.size(0)))
     return res[0] if return_single else res
+
+
+def _expand_binary_labels(labels, label_weights, label_channels):
+    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
+    inds = torch.nonzero(labels >= 1).squeeze()
+    if inds.numel() > 0:
+        bin_labels[inds, labels[inds] - 1] = 1
+    bin_label_weights = label_weights.view(-1, 1).expand(
+        label_weights.size(0), label_channels)
+    return bin_labels, bin_label_weights