diff --git a/mmdet/models/anchor_heads/anchor_head.py b/mmdet/models/anchor_heads/anchor_head.py index d059c65a509ab18d2e13e5f4592bc345c1cba5df..54096a4f814da76788b88f6e0c0d2355485a2957 100644 --- a/mmdet/models/anchor_heads/anchor_head.py +++ b/mmdet/models/anchor_heads/anchor_head.py @@ -25,9 +25,9 @@ class AnchorHead(nn.Module): anchor_base_sizes (Iterable): Anchor base sizes. target_means (Iterable): Mean values of regression targets. target_stds (Iterable): Std values of regression targets. - use_sigmoid_cls (bool): Whether to use sigmoid loss for classification. - (softmax by default) - use_focal_loss (bool): Whether to use focal loss for classification. + use_sigmoid_cls (bool): Whether to use sigmoid loss for + classification. (softmax by default) + cls_focal_loss (bool): Whether to use focal loss for classification. """ # noqa: W605 def __init__(self, @@ -41,7 +41,7 @@ class AnchorHead(nn.Module): target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0), use_sigmoid_cls=False, - use_focal_loss=False): + cls_focal_loss=False): super(AnchorHead, self).__init__() self.in_channels = in_channels self.num_classes = num_classes @@ -54,7 +54,7 @@ class AnchorHead(nn.Module): self.target_means = target_means self.target_stds = target_stds self.use_sigmoid_cls = use_sigmoid_cls - self.use_focal_loss = use_focal_loss + self.cls_focal_loss = cls_focal_loss self.anchor_generators = [] for anchor_base in self.anchor_base_sizes: @@ -133,16 +133,16 @@ class AnchorHead(nn.Module): cls_score = cls_score.permute(0, 2, 3, 1).reshape( -1, self.cls_out_channels) if self.use_sigmoid_cls: - if self.use_focal_loss: + if self.cls_focal_loss: cls_criterion = weighted_sigmoid_focal_loss else: cls_criterion = weighted_binary_cross_entropy else: - if self.use_focal_loss: + if self.cls_focal_loss: raise NotImplementedError else: cls_criterion = weighted_cross_entropy - if self.use_focal_loss: + if self.cls_focal_loss: loss_cls = cls_criterion( cls_score, labels, @@ -178,7 +178,7 @@ class AnchorHead(nn.Module): anchor_list, valid_flag_list = self.get_anchors( featmap_sizes, img_metas) - sampling = False if self.use_focal_loss else True + sampling = False if self.cls_focal_loss else True label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 cls_reg_targets = anchor_target( anchor_list, @@ -196,7 +196,7 @@ class AnchorHead(nn.Module): return None (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets - num_total_samples = (num_total_pos if self.use_focal_loss else + num_total_samples = (num_total_pos if self.cls_focal_loss else num_total_pos + num_total_neg) losses_cls, losses_reg = multi_apply( self.loss_single, diff --git a/mmdet/models/anchor_heads/retina_head.py b/mmdet/models/anchor_heads/retina_head.py index 815e8f0f91601588e31b9ce16a821c4e6aa57873..3eefca4a13744b290832f7c5f3024dd4113112e9 100644 --- a/mmdet/models/anchor_heads/retina_head.py +++ b/mmdet/models/anchor_heads/retina_head.py @@ -32,7 +32,7 @@ class RetinaHead(AnchorHead): in_channels, anchor_scales=anchor_scales, use_sigmoid_cls=True, - use_focal_loss=True, + cls_focal_loss=True, **kwargs) def _init_layers(self): diff --git a/mmdet/models/anchor_heads/ssd_head.py b/mmdet/models/anchor_heads/ssd_head.py index 1f704f74cdc8d527bb227a8bc502eab9ae872061..9c8b2a11e5592cefc873d22c8ab88c54c3fb127a 100644 --- a/mmdet/models/anchor_heads/ssd_head.py +++ b/mmdet/models/anchor_heads/ssd_head.py @@ -90,7 +90,7 @@ class SSDHead(AnchorHead): self.target_means = target_means self.target_stds = target_stds self.use_sigmoid_cls = False - self.use_focal_loss = False + self.cls_focal_loss = False def init_weights(self): for m in self.modules():