From d5d25be2d480b33a12bc7284107ec77c2eb6975b Mon Sep 17 00:00:00 2001 From: Cao Yuhang <yhcao6@gmail.com> Date: Mon, 15 Apr 2019 11:25:57 +0800 Subject: [PATCH] fix bug in two stage and weighted_binary_cross_entropy_loss --- mmdet/core/anchor/anchor_target.py | 10 ---------- mmdet/core/loss/losses.py | 12 ++++++++++++ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/mmdet/core/anchor/anchor_target.py b/mmdet/core/anchor/anchor_target.py index 2648977..60c902e 100644 --- a/mmdet/core/anchor/anchor_target.py +++ b/mmdet/core/anchor/anchor_target.py @@ -159,16 +159,6 @@ def anchor_target_single(flat_anchors, neg_inds) -def expand_binary_labels(labels, label_weights, label_channels): - bin_labels = labels.new_full((labels.size(0), label_channels), 0) - inds = torch.nonzero(labels >= 1).squeeze() - if inds.numel() > 0: - bin_labels[inds, labels[inds] - 1] = 1 - bin_label_weights = label_weights.view(-1, 1).expand( - label_weights.size(0), label_channels) - return bin_labels, bin_label_weights - - def anchor_inside_flags(flat_anchors, valid_flags, img_shape, allowed_border=0): img_h, img_w = img_shape[:2] diff --git a/mmdet/core/loss/losses.py b/mmdet/core/loss/losses.py index 922e058..1c5bf70 100644 --- a/mmdet/core/loss/losses.py +++ b/mmdet/core/loss/losses.py @@ -23,6 +23,8 @@ def weighted_cross_entropy(pred, label, weight, avg_factor=None, reduce=True): def weighted_binary_cross_entropy(pred, label, weight, avg_factor=None): + if pred.dim() != label.dim(): + label, weight = _expand_binary_labels(label, weight, pred.size(-1)) if avg_factor is None: avg_factor = max(torch.sum(weight > 0).float().item(), 1.) return F.binary_cross_entropy_with_logits( @@ -115,3 +117,13 @@ def accuracy(pred, target, topk=1): correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / pred.size(0))) return res[0] if return_single else res + + +def _expand_binary_labels(labels, label_weights, label_channels): + bin_labels = labels.new_full((labels.size(0), label_channels), 0) + inds = torch.nonzero(labels >= 1).squeeze() + if inds.numel() > 0: + bin_labels[inds, labels[inds] - 1] = 1 + bin_label_weights = label_weights.view(-1, 1).expand( + label_weights.size(0), label_channels) + return bin_labels, bin_label_weights -- GitLab