diff --git a/mmdet/core/bbox/bbox_target.py b/mmdet/core/bbox/bbox_target.py index 4a0450d915784ec1f5bd67963bbb27a743e71044..aa1fbc67430672185a1a01cbc5338a1912928b84 100644 --- a/mmdet/core/bbox/bbox_target.py +++ b/mmdet/core/bbox/bbox_target.py @@ -57,9 +57,6 @@ def bbox_target_single(pos_bboxes, bbox_weights[:num_pos, :] = 1 if num_neg > 0: label_weights[-num_neg:] = 1.0 - if reg_classes > 1: - bbox_targets, bbox_weights = expand_target(bbox_targets, bbox_weights, - labels, reg_classes) return labels, label_weights, bbox_targets, bbox_weights diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py index 2168e2e156814dbfd875915335ce2255a9df6c19..092a8121d393077b0593efed7bdf3500982adc6f 100644 --- a/mmdet/models/bbox_heads/bbox_head.py +++ b/mmdet/models/bbox_heads/bbox_head.py @@ -1,9 +1,9 @@ import torch import torch.nn as nn import torch.nn.functional as F - from mmdet.core import (delta2bbox, multiclass_nms, bbox_target, weighted_cross_entropy, weighted_smoothl1, accuracy) + from ..registry import HEADS @@ -94,10 +94,16 @@ class BBoxHead(nn.Module): cls_score, labels, label_weights, reduce=reduce) losses['acc'] = accuracy(cls_score, labels) if bbox_pred is not None: + pos_mask = labels > 0 + if self.reg_class_agnostic: + pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_mask] + else: + pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, + 4)[pos_mask, labels[pos_mask]] losses['loss_reg'] = weighted_smoothl1( - bbox_pred, - bbox_targets, - bbox_weights, + pos_bbox_pred, + bbox_targets[pos_mask], + bbox_weights[pos_mask], avg_factor=bbox_targets.size(0)) return losses