diff --git a/mmdet/core/losses/losses.py b/mmdet/core/losses/losses.py
index d0e642f807c94844d4442c8ef119e0a11ec2820f..14b49f5cb90ccc29240622a0c2a6764ae4c68520 100644
--- a/mmdet/core/losses/losses.py
+++ b/mmdet/core/losses/losses.py
@@ -36,7 +36,7 @@ def sigmoid_focal_loss(pred,
     weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
     weight = weight * pt.pow(gamma)
     return F.binary_cross_entropy_with_logits(
-        pred, target, weight, size_average=reduction)
+        pred, target, weight, reduction=reduction)
 
 
 def weighted_sigmoid_focal_loss(pred,
@@ -61,16 +61,6 @@ def mask_cross_entropy(pred, target, label):
         pred_slice, target, reduction='elementwise_mean')[None]
 
 
-def weighted_mask_cross_entropy(pred, target, weight, label):
-    num_rois = pred.size()[0]
-    num_samples = torch.sum(weight > 0).float().item() + 1e-6
-    assert num_samples >= 1
-    inds = torch.arange(0, num_rois).long().cuda()
-    pred_slice = pred[inds, label].squeeze(1)
-    return F.binary_cross_entropy_with_logits(
-        pred_slice, target, weight, size_average=False)[None] / num_samples
-
-
 def smooth_l1_loss(pred, target, beta=1.0, reduction='elementwise_mean'):
     assert beta > 0
     assert pred.size() == target.size() and target.numel() > 0
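
For context on the keyword fix in the first hunk: PyTorch 0.4.1+ replaced the
boolean `size_average`/`reduce` flags of `F.binary_cross_entropy_with_logits`
with a single string-valued `reduction` argument, so forwarding the string
`reduction` into `size_average=` was a bug. A minimal sketch of the corrected
call, assuming a recent PyTorch where the 'elementwise_mean' mode used in this
file has since been renamed 'mean':

    import torch
    import torch.nn.functional as F

    pred = torch.randn(4)                       # raw logits
    target = torch.randint(0, 2, (4,)).float()  # binary targets

    # Old (pre-0.4.1) API: size_average=True/False, booleans only.
    # New API, as used in the patched line: a string reduction mode.
    loss = F.binary_cross_entropy_with_logits(pred, target, reduction='mean')
    print(loss.item())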