Skip to content
Snippets Groups Projects
Commit a772bea5 authored by Kai Chen's avatar Kai Chen
Browse files

Merge branch 'dev' into mask-debug

parents bb6ef3b3 cedd8a82
No related branches found
No related tags found
No related merge requests found
...@@ -36,7 +36,7 @@ def sigmoid_focal_loss(pred, ...@@ -36,7 +36,7 @@ def sigmoid_focal_loss(pred,
weight = (alpha * target + (1 - alpha) * (1 - target)) * weight weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
weight = weight * pt.pow(gamma) weight = weight * pt.pow(gamma)
return F.binary_cross_entropy_with_logits( return F.binary_cross_entropy_with_logits(
pred, target, weight, size_average=reduction) pred, target, weight, reduction=reduction)
def weighted_sigmoid_focal_loss(pred, def weighted_sigmoid_focal_loss(pred,
...@@ -61,16 +61,6 @@ def mask_cross_entropy(pred, target, label): ...@@ -61,16 +61,6 @@ def mask_cross_entropy(pred, target, label):
pred_slice, target, reduction='elementwise_mean')[None] pred_slice, target, reduction='elementwise_mean')[None]
def weighted_mask_cross_entropy(pred, target, weight, label):
    """Weighted binary cross-entropy loss for mask predictions.

    For each RoI, selects the mask channel matching its class label and
    computes a weighted BCE-with-logits loss, normalized by the number of
    positive (weight > 0) samples.

    Args:
        pred: Mask logits of shape (num_rois, num_classes, ...).
        target: Binary mask targets matching the sliced prediction shape.
        weight: Per-element loss weights; entries > 0 count as samples.
        label: Class index per RoI, shape (num_rois,).

    Returns:
        A 1-element tensor holding the normalized summed loss.
    """
    num_rois = pred.size(0)
    # The 1e-6 guards the division below; the assert still requires at
    # least one positive sample (count >= 1), preserving the original
    # contract.
    num_samples = torch.sum(weight > 0).float().item() + 1e-6
    assert num_samples >= 1
    # Build indices on pred's own device instead of hard-coding .cuda(),
    # so the loss also works for CPU tensors (backward-compatible for
    # CUDA inputs).
    inds = torch.arange(num_rois, dtype=torch.long, device=pred.device)
    # squeeze(1) drops a singleton dim left by the fancy indexing;
    # presumably pred carries a trailing singleton axis here — verify
    # against callers.
    pred_slice = pred[inds, label].squeeze(1)
    # `size_average` was deprecated/removed in PyTorch; reduction='sum'
    # is the modern equivalent of size_average=False (same migration the
    # sibling sigmoid_focal_loss in this file received).
    return F.binary_cross_entropy_with_logits(
        pred_slice, target, weight, reduction='sum')[None] / num_samples
def smooth_l1_loss(pred, target, beta=1.0, reduction='elementwise_mean'): def smooth_l1_loss(pred, target, beta=1.0, reduction='elementwise_mean'):
assert beta > 0 assert beta > 0
assert pred.size() == target.size() and target.numel() > 0 assert pred.size() == target.size() and target.numel() > 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment