From 1ec92ef4094c239ae138ace52787a2496e4cb8df Mon Sep 17 00:00:00 2001
From: myownskyW7 <1155098160@link.cuhk.edu.hk>
Date: Mon, 13 May 2019 17:21:29 +0800
Subject: [PATCH] rename use_focal_loss -> cls_focal_loss (#639)

* use_sigmoid_cls -> cls_sigmoid_loss; use_focal_loss -> cls_focal_loss

* fix flake8 error

* cls_sigmoid_loss -> use_sigmoid_cls
---
 mmdet/models/anchor_heads/anchor_head.py | 20 ++++++++++----------
 mmdet/models/anchor_heads/retina_head.py |  2 +-
 mmdet/models/anchor_heads/ssd_head.py    |  2 +-
 3 files changed, 12 insertions(+), 12 deletions(-)

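A minimal usage sketch of the renamed keyword (not part of the applied diff), assuming
AnchorHead is importable from mmdet.models.anchor_heads.anchor_head and that num_classes
and in_channels are its only required arguments:

    from mmdet.models.anchor_heads.anchor_head import AnchorHead

    # Before this patch the flag was passed as use_focal_loss=...
    # head = AnchorHead(num_classes=81, in_channels=256, use_focal_loss=False)

    # After this patch the same flag is spelled cls_focal_loss=...
    head = AnchorHead(num_classes=81, in_channels=256, cls_focal_loss=False)
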
diff --git a/mmdet/models/anchor_heads/anchor_head.py b/mmdet/models/anchor_heads/anchor_head.py
index d059c65..54096a4 100644
--- a/mmdet/models/anchor_heads/anchor_head.py
+++ b/mmdet/models/anchor_heads/anchor_head.py
@@ -25,9 +25,9 @@ class AnchorHead(nn.Module):
         anchor_base_sizes (Iterable): Anchor base sizes.
         target_means (Iterable): Mean values of regression targets.
         target_stds (Iterable): Std values of regression targets.
-        use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
-            (softmax by default)
-        use_focal_loss (bool): Whether to use focal loss for classification.
+        use_sigmoid_cls (bool): Whether to use sigmoid loss for
+            classification. (softmax by default)
+        cls_focal_loss (bool): Whether to use focal loss for classification.
     """  # noqa: W605
 
     def __init__(self,
@@ -41,7 +41,7 @@ class AnchorHead(nn.Module):
                  target_means=(.0, .0, .0, .0),
                  target_stds=(1.0, 1.0, 1.0, 1.0),
                  use_sigmoid_cls=False,
-                 use_focal_loss=False):
+                 cls_focal_loss=False):
         super(AnchorHead, self).__init__()
         self.in_channels = in_channels
         self.num_classes = num_classes
@@ -54,7 +54,7 @@ class AnchorHead(nn.Module):
         self.target_means = target_means
         self.target_stds = target_stds
         self.use_sigmoid_cls = use_sigmoid_cls
-        self.use_focal_loss = use_focal_loss
+        self.cls_focal_loss = cls_focal_loss
 
         self.anchor_generators = []
         for anchor_base in self.anchor_base_sizes:
@@ -133,16 +133,16 @@ class AnchorHead(nn.Module):
         cls_score = cls_score.permute(0, 2, 3, 1).reshape(
             -1, self.cls_out_channels)
         if self.use_sigmoid_cls:
-            if self.use_focal_loss:
+            if self.cls_focal_loss:
                 cls_criterion = weighted_sigmoid_focal_loss
             else:
                 cls_criterion = weighted_binary_cross_entropy
         else:
-            if self.use_focal_loss:
+            if self.cls_focal_loss:
                 raise NotImplementedError
             else:
                 cls_criterion = weighted_cross_entropy
-        if self.use_focal_loss:
+        if self.cls_focal_loss:
             loss_cls = cls_criterion(
                 cls_score,
                 labels,
@@ -178,7 +178,7 @@ class AnchorHead(nn.Module):
 
         anchor_list, valid_flag_list = self.get_anchors(
             featmap_sizes, img_metas)
-        sampling = False if self.use_focal_loss else True
+        sampling = False if self.cls_focal_loss else True
         label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
         cls_reg_targets = anchor_target(
             anchor_list,
@@ -196,7 +196,7 @@ class AnchorHead(nn.Module):
             return None
         (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
          num_total_pos, num_total_neg) = cls_reg_targets
-        num_total_samples = (num_total_pos if self.use_focal_loss else
+        num_total_samples = (num_total_pos if self.cls_focal_loss else
                              num_total_pos + num_total_neg)
         losses_cls, losses_reg = multi_apply(
             self.loss_single,
diff --git a/mmdet/models/anchor_heads/retina_head.py b/mmdet/models/anchor_heads/retina_head.py
index 815e8f0..3eefca4 100644
--- a/mmdet/models/anchor_heads/retina_head.py
+++ b/mmdet/models/anchor_heads/retina_head.py
@@ -32,7 +32,7 @@ class RetinaHead(AnchorHead):
             in_channels,
             anchor_scales=anchor_scales,
             use_sigmoid_cls=True,
-            use_focal_loss=True,
+            cls_focal_loss=True,
             **kwargs)
 
     def _init_layers(self):
diff --git a/mmdet/models/anchor_heads/ssd_head.py b/mmdet/models/anchor_heads/ssd_head.py
index 1f704f7..9c8b2a1 100644
--- a/mmdet/models/anchor_heads/ssd_head.py
+++ b/mmdet/models/anchor_heads/ssd_head.py
@@ -90,7 +90,7 @@ class SSDHead(AnchorHead):
         self.target_means = target_means
         self.target_stds = target_stds
         self.use_sigmoid_cls = False
-        self.use_focal_loss = False
+        self.cls_focal_loss = False
 
     def init_weights(self):
         for m in self.modules():
-- 
GitLab