From bd8fd27eb2949ce4f8bf22bd765cc1d57f785543 Mon Sep 17 00:00:00 2001 From: Cao Yuhang <yhcao6@gmail.com> Date: Sun, 14 Apr 2019 16:01:06 +0800 Subject: [PATCH] rename pos_mask to pos_inds --- mmdet/models/bbox_heads/bbox_head.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py index 092a812..4dcbd97 100644 --- a/mmdet/models/bbox_heads/bbox_head.py +++ b/mmdet/models/bbox_heads/bbox_head.py @@ -1,9 +1,9 @@ import torch import torch.nn as nn import torch.nn.functional as F + from mmdet.core import (delta2bbox, multiclass_nms, bbox_target, weighted_cross_entropy, weighted_smoothl1, accuracy) - from ..registry import HEADS @@ -94,16 +94,16 @@ class BBoxHead(nn.Module): cls_score, labels, label_weights, reduce=reduce) losses['acc'] = accuracy(cls_score, labels) if bbox_pred is not None: - pos_mask = labels > 0 + pos_inds = labels > 0 if self.reg_class_agnostic: - pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_mask] + pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_inds] else: pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, - 4)[pos_mask, labels[pos_mask]] + 4)[pos_inds, labels[pos_inds]] losses['loss_reg'] = weighted_smoothl1( pos_bbox_pred, - bbox_targets[pos_mask], - bbox_weights[pos_mask], + bbox_targets[pos_inds], + bbox_weights[pos_inds], avg_factor=bbox_targets.size(0)) return losses -- GitLab