diff --git a/mmdet/core/anchor/anchor_target.py b/mmdet/core/anchor/anchor_target.py
index abb3047c5204cce04da25b5ad3d654185a9323a2..2dae8315189afc46011e3fcd39eb722cc220a695 100644
--- a/mmdet/core/anchor/anchor_target.py
+++ b/mmdet/core/anchor/anchor_target.py
@@ -158,8 +158,7 @@ def anchor_target_single(flat_anchors,
 
 
 def expand_binary_labels(labels, label_weights, label_channels):
-    bin_labels = labels.new_full(
-        (labels.size(0), label_channels), 0, dtype=torch.float32)
+    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
     inds = torch.nonzero(labels >= 1).squeeze()
     if inds.numel() > 0:
         bin_labels[inds, labels[inds] - 1] = 1
diff --git a/mmdet/core/loss/losses.py b/mmdet/core/loss/losses.py
index bf9a44d5e58b0b0ff414cec94efeae1c73e0105d..560dac43ffd6c90be0da82ddc9d8a78cf2aba129 100644
--- a/mmdet/core/loss/losses.py
+++ b/mmdet/core/loss/losses.py
@@ -10,8 +10,7 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
     return torch.sum(raw * weight)[None] / avg_factor
 
 
-def weighted_cross_entropy(pred, label, weight, avg_factor=None,
-                           reduce=True):
+def weighted_cross_entropy(pred, label, weight, avg_factor=None, reduce=True):
     if avg_factor is None:
         avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
     raw = F.cross_entropy(pred, label, reduction='none')
@@ -36,6 +35,7 @@ def sigmoid_focal_loss(pred,
                        alpha=0.25,
                        reduction='mean'):
     pred_sigmoid = pred.sigmoid()
+    target = target.type_as(pred)
     pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
     weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
     weight = weight * pt.pow(gamma)