diff --git a/mmdet/core/anchor/anchor_target.py b/mmdet/core/anchor/anchor_target.py
index abb3047c5204cce04da25b5ad3d654185a9323a2..2dae8315189afc46011e3fcd39eb722cc220a695 100644
--- a/mmdet/core/anchor/anchor_target.py
+++ b/mmdet/core/anchor/anchor_target.py
@@ -158,8 +158,7 @@ def anchor_target_single(flat_anchors,
 
 
 def expand_binary_labels(labels, label_weights, label_channels):
-    bin_labels = labels.new_full(
-        (labels.size(0), label_channels), 0, dtype=torch.float32)
+    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
     inds = torch.nonzero(labels >= 1).squeeze()
     if inds.numel() > 0:
         bin_labels[inds, labels[inds] - 1] = 1
diff --git a/mmdet/core/loss/losses.py b/mmdet/core/loss/losses.py
index bf9a44d5e58b0b0ff414cec94efeae1c73e0105d..560dac43ffd6c90be0da82ddc9d8a78cf2aba129 100644
--- a/mmdet/core/loss/losses.py
+++ b/mmdet/core/loss/losses.py
@@ -10,8 +10,7 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
     return torch.sum(raw * weight)[None] / avg_factor
 
 
-def weighted_cross_entropy(pred, label, weight, avg_factor=None,
-                           reduce=True):
+def weighted_cross_entropy(pred, label, weight, avg_factor=None, reduce=True):
     if avg_factor is None:
         avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
     raw = F.cross_entropy(pred, label, reduction='none')
@@ -36,6 +35,7 @@ def sigmoid_focal_loss(pred,
                        alpha=0.25,
                        reduction='mean'):
     pred_sigmoid = pred.sigmoid()
+    target = target.type_as(pred)
     pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
     weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
     weight = weight * pt.pow(gamma)
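A minimal sketch, not part of the patch, of how the two changed call sites interact: `expand_binary_labels` now builds its one-hot target without an explicit `dtype=torch.float32` (so it keeps the integer dtype of `labels`), and the added `target = target.type_as(pred)` line in `sigmoid_focal_loss` casts that target to the prediction's floating dtype before the focal-loss arithmetic. The tensor values and shapes below are illustrative assumptions, not taken from the repository.

```python
import torch

# Illustrative inputs (assumed, not from the patch): 0 = background, >=1 = class index.
labels = torch.tensor([2, 0, 1])
label_channels = 3

# One-hot expansion as in the patched expand_binary_labels: without an
# explicit dtype, new_full inherits the integer dtype of `labels`.
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
inds = torch.nonzero(labels >= 1).squeeze()
if inds.numel() > 0:
    bin_labels[inds, labels[inds] - 1] = 1

# Inside the patched sigmoid_focal_loss, the integer target is cast to the
# prediction's floating dtype before computing pt and the focal weight.
pred = torch.randn(labels.size(0), label_channels)  # raw logits
target = bin_labels.type_as(pred)
pred_sigmoid = pred.sigmoid()
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
print(bin_labels.dtype, target.dtype, pt.dtype)  # e.g. torch.int64 torch.float32 torch.float32
```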