Skip to content
Snippets Groups Projects
Unverified Commit f2cfa86b authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #310 from hellock/master

Bug fix for retinanet with 2 classes (fg/bg)
parents ba73bcc5 d1cf5e59
No related branches found
No related tags found
No related merge requests found
......@@ -158,8 +158,7 @@ def anchor_target_single(flat_anchors,
def expand_binary_labels(labels, label_weights, label_channels):
bin_labels = labels.new_full(
(labels.size(0), label_channels), 0, dtype=torch.float32)
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
inds = torch.nonzero(labels >= 1).squeeze()
if inds.numel() > 0:
bin_labels[inds, labels[inds] - 1] = 1
......
......@@ -10,8 +10,7 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
return torch.sum(raw * weight)[None] / avg_factor
def weighted_cross_entropy(pred, label, weight, avg_factor=None,
reduce=True):
def weighted_cross_entropy(pred, label, weight, avg_factor=None, reduce=True):
if avg_factor is None:
avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
raw = F.cross_entropy(pred, label, reduction='none')
......@@ -36,6 +35,7 @@ def sigmoid_focal_loss(pred,
alpha=0.25,
reduction='mean'):
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
weight = weight * pt.pow(gamma)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment