Skip to content
Snippets Groups Projects
Commit d5d25be2 authored by Cao Yuhang's avatar Cao Yuhang
Browse files

fix bug in two stage and weighted_binary_cross_entropy_loss

parent d67a2e16
No related branches found
No related tags found
No related merge requests found
...@@ -159,16 +159,6 @@ def anchor_target_single(flat_anchors, ...@@ -159,16 +159,6 @@ def anchor_target_single(flat_anchors,
neg_inds) neg_inds)
def expand_binary_labels(labels, label_weights, label_channels):
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
inds = torch.nonzero(labels >= 1).squeeze()
if inds.numel() > 0:
bin_labels[inds, labels[inds] - 1] = 1
bin_label_weights = label_weights.view(-1, 1).expand(
label_weights.size(0), label_channels)
return bin_labels, bin_label_weights
def anchor_inside_flags(flat_anchors, valid_flags, img_shape, def anchor_inside_flags(flat_anchors, valid_flags, img_shape,
allowed_border=0): allowed_border=0):
img_h, img_w = img_shape[:2] img_h, img_w = img_shape[:2]
......
...@@ -23,6 +23,8 @@ def weighted_cross_entropy(pred, label, weight, avg_factor=None, reduce=True): ...@@ -23,6 +23,8 @@ def weighted_cross_entropy(pred, label, weight, avg_factor=None, reduce=True):
def weighted_binary_cross_entropy(pred, label, weight, avg_factor=None): def weighted_binary_cross_entropy(pred, label, weight, avg_factor=None):
if pred.dim() != label.dim():
label, weight = _expand_binary_labels(label, weight, pred.size(-1))
if avg_factor is None: if avg_factor is None:
avg_factor = max(torch.sum(weight > 0).float().item(), 1.) avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
return F.binary_cross_entropy_with_logits( return F.binary_cross_entropy_with_logits(
...@@ -115,3 +117,13 @@ def accuracy(pred, target, topk=1): ...@@ -115,3 +117,13 @@ def accuracy(pred, target, topk=1):
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / pred.size(0))) res.append(correct_k.mul_(100.0 / pred.size(0)))
return res[0] if return_single else res return res[0] if return_single else res
def _expand_binary_labels(labels, label_weights, label_channels):
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
inds = torch.nonzero(labels >= 1).squeeze()
if inds.numel() > 0:
bin_labels[inds, labels[inds] - 1] = 1
bin_label_weights = label_weights.view(-1, 1).expand(
label_weights.size(0), label_channels)
return bin_labels, bin_label_weights
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment