From 64928acc89bfe7e24a5ce948baf98c5a9186b27b Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Mon, 13 May 2019 01:16:54 -0700
Subject: [PATCH] Rename normalize to norm_cfg (#637)

* rename normalize to norm_cfg

* update configs

* Update resnet.py
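
This renames the `normalize` argument to `norm_cfg` throughout the configs
and models, matching the existing `conv_cfg` naming, and changes the default
of ConvModule's `bias` argument to 'auto' so the bias term is dropped
automatically whenever a norm layer follows the conv. Configs migrate by a
plain rename; an abridged before/after (other fields trimmed for brevity):

    # before
    normalize = dict(type='GN', num_groups=32, requires_grad=True)
    model = dict(backbone=dict(type='ResNet', depth=50, normalize=normalize))

    # after
    norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
    model = dict(backbone=dict(type='ResNet', depth=50, norm_cfg=norm_cfg))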
---
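Notes (below the --- separator; ignored by `git am`):

With the new bias='auto' default, callers no longer need to pass
`bias=normalize is None` by hand, and the extra convs in FPN pick up the
previously missing conv_cfg argument. A minimal sketch of the bias
behavior, assuming ConvModule is exported from mmdet.models.utils as in
this tree:

    from mmdet.models.utils import ConvModule

    # norm layer configured -> the conv is built without a bias term
    conv_bn = ConvModule(256, 256, 3, padding=1, norm_cfg=dict(type='BN'))
    assert not conv_bn.with_bias

    # no norm layer -> bias is enabled automatically
    conv_plain = ConvModule(256, 256, 3, padding=1)
    assert conv_plain.with_bias
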
 configs/cascade_mask_rcnn_r50_c4_1x.py        |  5 +-
 configs/cascade_rcnn_r50_c4_1x.py             |  5 +-
 configs/fast_rcnn_r50_c4_1x.py                |  5 +-
 configs/faster_rcnn_r50_c4_1x.py              |  5 +-
 configs/gn+ws/faster_rcnn_r50_fpn_gn_ws_1x.py |  8 +-
 .../mask_rcnn_r50_fpn_gn_ws_20_23_24e.py      | 10 +--
 configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_2x.py   | 10 +--
 .../mask_rcnn_x101_32x4d_fpn_gn_ws_2x.py      | 10 +--
 configs/gn/mask_rcnn_r101_fpn_gn_2x.py        | 10 +--
 configs/gn/mask_rcnn_r50_fpn_gn_2x.py         | 10 +--
 configs/gn/mask_rcnn_r50_fpn_gn_contrib_2x.py | 10 +--
 configs/mask_rcnn_r50_c4_1x.py                |  5 +-
 configs/rpn_r50_c4_1x.py                      |  2 +-
 mmdet/models/anchor_heads/retina_head.py      | 10 +--
 mmdet/models/backbones/resnet.py              | 47 ++++++------
 mmdet/models/backbones/resnext.py             | 18 ++---
 mmdet/models/bbox_heads/convfc_bbox_head.py   |  8 +-
 mmdet/models/mask_heads/fcn_mask_head.py      |  8 +-
 .../models/mask_heads/fused_semantic_head.py  | 14 ++--
 mmdet/models/mask_heads/htc_mask_head.py      |  3 +-
 mmdet/models/necks/fpn.py                     | 13 ++--
 mmdet/models/shared_heads/res_layer.py        |  6 +-
 mmdet/models/utils/conv_module.py             | 73 ++++++++++++++-----
 23 files changed, 159 insertions(+), 136 deletions(-)

diff --git a/configs/cascade_mask_rcnn_r50_c4_1x.py b/configs/cascade_mask_rcnn_r50_c4_1x.py
index d60a165..a31e518 100644
--- a/configs/cascade_mask_rcnn_r50_c4_1x.py
+++ b/configs/cascade_mask_rcnn_r50_c4_1x.py
@@ -1,4 +1,5 @@
 # model settings
+norm_cfg = dict(type='BN', requires_grad=False)
 model = dict(
     type='CascadeRCNN',
     num_stages=3,
@@ -11,7 +12,7 @@ model = dict(
         dilations=(1, 1, 1),
         out_indices=(2, ),
         frozen_stages=1,
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True,
         style='caffe'),
     shared_head=dict(
@@ -21,7 +22,7 @@ model = dict(
         stride=2,
         dilation=1,
         style='caffe',
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True),
     rpn_head=dict(
         type='RPNHead',
diff --git a/configs/cascade_rcnn_r50_c4_1x.py b/configs/cascade_rcnn_r50_c4_1x.py
index d2ea3d0..8224f2f 100644
--- a/configs/cascade_rcnn_r50_c4_1x.py
+++ b/configs/cascade_rcnn_r50_c4_1x.py
@@ -1,4 +1,5 @@
 # model settings
+norm_cfg = dict(type='BN', requires_grad=False)
 model = dict(
     type='CascadeRCNN',
     num_stages=3,
@@ -11,7 +12,7 @@ model = dict(
         dilations=(1, 1, 1),
         out_indices=(2, ),
         frozen_stages=1,
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True,
         style='caffe'),
     shared_head=dict(
@@ -21,7 +22,7 @@ model = dict(
         stride=2,
         dilation=1,
         style='caffe',
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True),
     rpn_head=dict(
         type='RPNHead',
diff --git a/configs/fast_rcnn_r50_c4_1x.py b/configs/fast_rcnn_r50_c4_1x.py
index 052e578..a204d13 100644
--- a/configs/fast_rcnn_r50_c4_1x.py
+++ b/configs/fast_rcnn_r50_c4_1x.py
@@ -1,4 +1,5 @@
 # model settings
+norm_cfg = dict(type='BN', requires_grad=False)
 model = dict(
     type='FastRCNN',
     pretrained='open-mmlab://resnet50_caffe',
@@ -10,7 +11,7 @@ model = dict(
         dilations=(1, 1, 1),
         out_indices=(2, ),
         frozen_stages=1,
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True,
         style='caffe'),
     shared_head=dict(
@@ -20,7 +21,7 @@ model = dict(
         stride=2,
         dilation=1,
         style='caffe',
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True),
     bbox_roi_extractor=dict(
         type='SingleRoIExtractor',
diff --git a/configs/faster_rcnn_r50_c4_1x.py b/configs/faster_rcnn_r50_c4_1x.py
index 0118320..61ae19d 100644
--- a/configs/faster_rcnn_r50_c4_1x.py
+++ b/configs/faster_rcnn_r50_c4_1x.py
@@ -1,4 +1,5 @@
 # model settings
+norm_cfg = dict(type='BN', requires_grad=False)
 model = dict(
     type='FasterRCNN',
     pretrained='open-mmlab://resnet50_caffe',
@@ -10,7 +11,7 @@ model = dict(
         dilations=(1, 1, 1),
         out_indices=(2, ),
         frozen_stages=1,
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True,
         style='caffe'),
     shared_head=dict(
@@ -20,7 +21,7 @@ model = dict(
         stride=2,
         dilation=1,
         style='caffe',
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True),
     rpn_head=dict(
         type='RPNHead',
diff --git a/configs/gn+ws/faster_rcnn_r50_fpn_gn_ws_1x.py b/configs/gn+ws/faster_rcnn_r50_fpn_gn_ws_1x.py
index 067703d..ef63583 100644
--- a/configs/gn+ws/faster_rcnn_r50_fpn_gn_ws_1x.py
+++ b/configs/gn+ws/faster_rcnn_r50_fpn_gn_ws_1x.py
@@ -1,6 +1,6 @@
 # model settings
 conv_cfg = dict(type='ConvWS')
-normalize = dict(type='GN', num_groups=32, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
 model = dict(
     type='FasterRCNN',
     pretrained='open-mmlab://jhu/resnet50_gn_ws',
@@ -12,14 +12,14 @@ model = dict(
         frozen_stages=1,
         style='pytorch',
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         num_outs=5,
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     rpn_head=dict(
         type='RPNHead',
         in_channels=256,
@@ -48,7 +48,7 @@ model = dict(
         target_stds=[0.1, 0.1, 0.2, 0.2],
         reg_class_agnostic=False,
         conv_cfg=conv_cfg,
-        normalize=normalize))
+        norm_cfg=norm_cfg))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
diff --git a/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_20_23_24e.py b/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_20_23_24e.py
index 12a0c2b..5a79bfd 100644
--- a/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_20_23_24e.py
+++ b/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_20_23_24e.py
@@ -1,6 +1,6 @@
 # model settings
 conv_cfg = dict(type='ConvWS')
-normalize = dict(type='GN', num_groups=32, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
 model = dict(
     type='MaskRCNN',
     pretrained='open-mmlab://jhu/resnet50_gn_ws',
@@ -12,14 +12,14 @@ model = dict(
         frozen_stages=1,
         style='pytorch',
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         num_outs=5,
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     rpn_head=dict(
         type='RPNHead',
         in_channels=256,
@@ -48,7 +48,7 @@ model = dict(
         target_stds=[0.1, 0.1, 0.2, 0.2],
         reg_class_agnostic=False,
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     mask_roi_extractor=dict(
         type='SingleRoIExtractor',
         roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
@@ -61,7 +61,7 @@ model = dict(
         conv_out_channels=256,
         num_classes=81,
         conv_cfg=conv_cfg,
-        normalize=normalize))
+        norm_cfg=norm_cfg))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
diff --git a/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_2x.py b/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_2x.py
index 7ae7740..7294aee 100644
--- a/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_2x.py
+++ b/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_2x.py
@@ -1,6 +1,6 @@
 # model settings
 conv_cfg = dict(type='ConvWS')
-normalize = dict(type='GN', num_groups=32, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
 model = dict(
     type='MaskRCNN',
     pretrained='open-mmlab://jhu/resnet50_gn_ws',
@@ -12,14 +12,14 @@ model = dict(
         frozen_stages=1,
         style='pytorch',
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         num_outs=5,
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     rpn_head=dict(
         type='RPNHead',
         in_channels=256,
@@ -48,7 +48,7 @@ model = dict(
         target_stds=[0.1, 0.1, 0.2, 0.2],
         reg_class_agnostic=False,
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     mask_roi_extractor=dict(
         type='SingleRoIExtractor',
         roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
@@ -61,7 +61,7 @@ model = dict(
         conv_out_channels=256,
         num_classes=81,
         conv_cfg=conv_cfg,
-        normalize=normalize))
+        norm_cfg=norm_cfg))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
diff --git a/configs/gn+ws/mask_rcnn_x101_32x4d_fpn_gn_ws_2x.py b/configs/gn+ws/mask_rcnn_x101_32x4d_fpn_gn_ws_2x.py
index 1d5fb98..4ed83b2 100644
--- a/configs/gn+ws/mask_rcnn_x101_32x4d_fpn_gn_ws_2x.py
+++ b/configs/gn+ws/mask_rcnn_x101_32x4d_fpn_gn_ws_2x.py
@@ -1,6 +1,6 @@
 # model settings
 conv_cfg = dict(type='ConvWS')
-normalize = dict(type='GN', num_groups=32, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
 model = dict(
     type='MaskRCNN',
     pretrained='open-mmlab://jhu/resnext101_32x4d_gn_ws',
@@ -14,14 +14,14 @@ model = dict(
         frozen_stages=1,
         style='pytorch',
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         num_outs=5,
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     rpn_head=dict(
         type='RPNHead',
         in_channels=256,
@@ -50,7 +50,7 @@ model = dict(
         target_stds=[0.1, 0.1, 0.2, 0.2],
         reg_class_agnostic=False,
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     mask_roi_extractor=dict(
         type='SingleRoIExtractor',
         roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
@@ -63,7 +63,7 @@ model = dict(
         conv_out_channels=256,
         num_classes=81,
         conv_cfg=conv_cfg,
-        normalize=normalize))
+        norm_cfg=norm_cfg))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
diff --git a/configs/gn/mask_rcnn_r101_fpn_gn_2x.py b/configs/gn/mask_rcnn_r101_fpn_gn_2x.py
index 97b0c95..3f61dc4 100644
--- a/configs/gn/mask_rcnn_r101_fpn_gn_2x.py
+++ b/configs/gn/mask_rcnn_r101_fpn_gn_2x.py
@@ -1,5 +1,5 @@
 # model settings
-normalize = dict(type='GN', num_groups=32, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
 
 model = dict(
     type='MaskRCNN',
@@ -11,13 +11,13 @@ model = dict(
         out_indices=(0, 1, 2, 3),
         frozen_stages=1,
         style='pytorch',
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         num_outs=5,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     rpn_head=dict(
         type='RPNHead',
         in_channels=256,
@@ -45,7 +45,7 @@ model = dict(
         target_means=[0., 0., 0., 0.],
         target_stds=[0.1, 0.1, 0.2, 0.2],
         reg_class_agnostic=False,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     mask_roi_extractor=dict(
         type='SingleRoIExtractor',
         roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
@@ -57,7 +57,7 @@ model = dict(
         in_channels=256,
         conv_out_channels=256,
         num_classes=81,
-        normalize=normalize))
+        norm_cfg=norm_cfg))
 
 # model training and testing settings
 train_cfg = dict(
diff --git a/configs/gn/mask_rcnn_r50_fpn_gn_2x.py b/configs/gn/mask_rcnn_r50_fpn_gn_2x.py
index edbd7d8..165c4aa 100644
--- a/configs/gn/mask_rcnn_r50_fpn_gn_2x.py
+++ b/configs/gn/mask_rcnn_r50_fpn_gn_2x.py
@@ -1,5 +1,5 @@
 # model settings
-normalize = dict(type='GN', num_groups=32, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
 
 model = dict(
     type='MaskRCNN',
@@ -11,13 +11,13 @@ model = dict(
         out_indices=(0, 1, 2, 3),
         frozen_stages=1,
         style='pytorch',
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         num_outs=5,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     rpn_head=dict(
         type='RPNHead',
         in_channels=256,
@@ -45,7 +45,7 @@ model = dict(
         target_means=[0., 0., 0., 0.],
         target_stds=[0.1, 0.1, 0.2, 0.2],
         reg_class_agnostic=False,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     mask_roi_extractor=dict(
         type='SingleRoIExtractor',
         roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
@@ -57,7 +57,7 @@ model = dict(
         in_channels=256,
         conv_out_channels=256,
         num_classes=81,
-        normalize=normalize))
+        norm_cfg=norm_cfg))
 
 # model training and testing settings
 train_cfg = dict(
diff --git a/configs/gn/mask_rcnn_r50_fpn_gn_contrib_2x.py b/configs/gn/mask_rcnn_r50_fpn_gn_contrib_2x.py
index 5c63992..00760fb 100644
--- a/configs/gn/mask_rcnn_r50_fpn_gn_contrib_2x.py
+++ b/configs/gn/mask_rcnn_r50_fpn_gn_contrib_2x.py
@@ -1,5 +1,5 @@
 # model settings
-normalize = dict(type='GN', num_groups=32, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
 
 model = dict(
     type='MaskRCNN',
@@ -11,13 +11,13 @@ model = dict(
         out_indices=(0, 1, 2, 3),
         frozen_stages=1,
         style='pytorch',
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         num_outs=5,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     rpn_head=dict(
         type='RPNHead',
         in_channels=256,
@@ -45,7 +45,7 @@ model = dict(
         target_means=[0., 0., 0., 0.],
         target_stds=[0.1, 0.1, 0.2, 0.2],
         reg_class_agnostic=False,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     mask_roi_extractor=dict(
         type='SingleRoIExtractor',
         roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
@@ -57,7 +57,7 @@ model = dict(
         in_channels=256,
         conv_out_channels=256,
         num_classes=81,
-        normalize=normalize))
+        norm_cfg=norm_cfg))
 
 # model training and testing settings
 train_cfg = dict(
diff --git a/configs/mask_rcnn_r50_c4_1x.py b/configs/mask_rcnn_r50_c4_1x.py
index 63884a4..e7584be 100644
--- a/configs/mask_rcnn_r50_c4_1x.py
+++ b/configs/mask_rcnn_r50_c4_1x.py
@@ -1,4 +1,5 @@
 # model settings
+norm_cfg = dict(type='BN', requires_grad=False)
 model = dict(
     type='MaskRCNN',
     pretrained='open-mmlab://resnet50_caffe',
@@ -10,7 +11,7 @@ model = dict(
         dilations=(1, 1, 1),
         out_indices=(2, ),
         frozen_stages=1,
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True,
         style='caffe'),
     shared_head=dict(
@@ -20,7 +21,7 @@ model = dict(
         stride=2,
         dilation=1,
         style='caffe',
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True),
     rpn_head=dict(
         type='RPNHead',
diff --git a/configs/rpn_r50_c4_1x.py b/configs/rpn_r50_c4_1x.py
index bc3d2a8..ac9ad83 100644
--- a/configs/rpn_r50_c4_1x.py
+++ b/configs/rpn_r50_c4_1x.py
@@ -10,7 +10,7 @@ model = dict(
         dilations=(1, 1, 1),
         out_indices=(2, ),
         frozen_stages=1,
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=dict(type='BN', requires_grad=False),
         norm_eval=True,
         style='caffe'),
     neck=None,
diff --git a/mmdet/models/anchor_heads/retina_head.py b/mmdet/models/anchor_heads/retina_head.py
index 27701f4..815e8f0 100644
--- a/mmdet/models/anchor_heads/retina_head.py
+++ b/mmdet/models/anchor_heads/retina_head.py
@@ -17,13 +17,13 @@ class RetinaHead(AnchorHead):
                  octave_base_scale=4,
                  scales_per_octave=3,
                  conv_cfg=None,
-                 normalize=None,
+                 norm_cfg=None,
                  **kwargs):
         self.stacked_convs = stacked_convs
         self.octave_base_scale = octave_base_scale
         self.scales_per_octave = scales_per_octave
         self.conv_cfg = conv_cfg
-        self.normalize = normalize
+        self.norm_cfg = norm_cfg
         octave_scales = np.array(
             [2**(i / scales_per_octave) for i in range(scales_per_octave)])
         anchor_scales = octave_scales * octave_base_scale
@@ -49,8 +49,7 @@ class RetinaHead(AnchorHead):
                     stride=1,
                     padding=1,
                     conv_cfg=self.conv_cfg,
-                    normalize=self.normalize,
-                    bias=self.normalize is None))
+                    norm_cfg=self.norm_cfg))
             self.reg_convs.append(
                 ConvModule(
                     chn,
@@ -59,8 +58,7 @@ class RetinaHead(AnchorHead):
                     stride=1,
                     padding=1,
                     conv_cfg=self.conv_cfg,
-                    normalize=self.normalize,
-                    bias=self.normalize is None))
+                    norm_cfg=self.norm_cfg))
         self.retina_cls = nn.Conv2d(
             self.feat_channels,
             self.num_anchors * self.cls_out_channels,
diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py
index c9daeef..037fbb1 100644
--- a/mmdet/models/backbones/resnet.py
+++ b/mmdet/models/backbones/resnet.py
@@ -23,13 +23,13 @@ class BasicBlock(nn.Module):
                  style='pytorch',
                  with_cp=False,
                  conv_cfg=None,
-                 normalize=dict(type='BN'),
+                 norm_cfg=dict(type='BN'),
                  dcn=None):
         super(BasicBlock, self).__init__()
         assert dcn is None, "Not implemented yet."
 
-        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
-        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
+        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
 
         self.conv1 = build_conv_layer(
             conv_cfg,
@@ -95,7 +95,7 @@ class Bottleneck(nn.Module):
                  style='pytorch',
                  with_cp=False,
                  conv_cfg=None,
-                 normalize=dict(type='BN'),
+                 norm_cfg=dict(type='BN'),
                  dcn=None):
         """Bottleneck block for ResNet.
         If style is "pytorch", the stride-two layer is the 3x3 conv layer,
@@ -106,21 +106,26 @@ class Bottleneck(nn.Module):
         assert dcn is None or isinstance(dcn, dict)
         self.inplanes = inplanes
         self.planes = planes
+        self.stride = stride
+        self.dilation = dilation
+        self.downsample = downsample
+        self.style = style
+        self.with_cp = with_cp
         self.conv_cfg = conv_cfg
-        self.normalize = normalize
+        self.norm_cfg = norm_cfg
         self.dcn = dcn
         self.with_dcn = dcn is not None
-        if style == 'pytorch':
+        if self.style == 'pytorch':
             self.conv1_stride = 1
             self.conv2_stride = stride
         else:
             self.conv1_stride = stride
             self.conv2_stride = 1
 
-        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
-        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
+        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
         self.norm3_name, norm3 = build_norm_layer(
-            normalize, planes * self.expansion, postfix=3)
+            norm_cfg, planes * self.expansion, postfix=3)
 
         self.conv1 = build_conv_layer(
             conv_cfg,
@@ -180,11 +185,6 @@ class Bottleneck(nn.Module):
         self.add_module(self.norm3_name, norm3)
 
         self.relu = nn.ReLU(inplace=True)
-        self.downsample = downsample
-        self.stride = stride
-        self.dilation = dilation
-        self.with_cp = with_cp
-        self.normalize = normalize
 
     @property
     def norm1(self):
@@ -249,7 +249,7 @@ def make_res_layer(block,
                    style='pytorch',
                    with_cp=False,
                    conv_cfg=None,
-                   normalize=dict(type='BN'),
+                   norm_cfg=dict(type='BN'),
                    dcn=None):
     downsample = None
     if stride != 1 or inplanes != planes * block.expansion:
@@ -261,7 +261,7 @@ def make_res_layer(block,
                 kernel_size=1,
                 stride=stride,
                 bias=False),
-            build_norm_layer(normalize, planes * block.expansion)[1],
+            build_norm_layer(norm_cfg, planes * block.expansion)[1],
         )
 
     layers = []
@@ -275,7 +275,7 @@ def make_res_layer(block,
             style=style,
             with_cp=with_cp,
             conv_cfg=conv_cfg,
-            normalize=normalize,
+            norm_cfg=norm_cfg,
             dcn=dcn))
     inplanes = planes * block.expansion
     for i in range(1, blocks):
@@ -288,7 +288,7 @@ def make_res_layer(block,
                 style=style,
                 with_cp=with_cp,
                 conv_cfg=conv_cfg,
-                normalize=normalize,
+                norm_cfg=norm_cfg,
                 dcn=dcn))
 
     return nn.Sequential(*layers)
@@ -309,7 +309,7 @@ class ResNet(nn.Module):
             the first 1x1 conv layer.
         frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
             -1 means not freezing any parameters.
-        normalize (dict): dictionary to construct and config norm layer.
+        norm_cfg (dict): Dictionary to construct and configure the norm layer.
         norm_eval (bool): Whether to set norm layers to eval mode, namely,
             freeze running stats (mean and var). Note: Effect on Batch Norm
             and its variants only.
@@ -336,7 +336,7 @@ class ResNet(nn.Module):
                  style='pytorch',
                  frozen_stages=-1,
                  conv_cfg=None,
-                 normalize=dict(type='BN', requires_grad=True),
+                 norm_cfg=dict(type='BN', requires_grad=True),
                  norm_eval=True,
                  dcn=None,
                  stage_with_dcn=(False, False, False, False),
@@ -356,7 +356,7 @@ class ResNet(nn.Module):
         self.style = style
         self.frozen_stages = frozen_stages
         self.conv_cfg = conv_cfg
-        self.normalize = normalize
+        self.norm_cfg = norm_cfg
         self.with_cp = with_cp
         self.norm_eval = norm_eval
         self.dcn = dcn
@@ -386,7 +386,7 @@ class ResNet(nn.Module):
                 style=self.style,
                 with_cp=with_cp,
                 conv_cfg=conv_cfg,
-                normalize=normalize,
+                norm_cfg=norm_cfg,
                 dcn=dcn)
             self.inplanes = planes * self.block.expansion
             layer_name = 'layer{}'.format(i + 1)
@@ -411,8 +411,7 @@ class ResNet(nn.Module):
             stride=2,
             padding=3,
             bias=False)
-        self.norm1_name, norm1 = build_norm_layer(
-            self.normalize, 64, postfix=1)
+        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
         self.add_module(self.norm1_name, norm1)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
diff --git a/mmdet/models/backbones/resnext.py b/mmdet/models/backbones/resnext.py
index 3cdb1cb..c869a02 100644
--- a/mmdet/models/backbones/resnext.py
+++ b/mmdet/models/backbones/resnext.py
@@ -24,11 +24,11 @@ class Bottleneck(_Bottleneck):
             width = math.floor(self.planes * (base_width / 64)) * groups
 
         self.norm1_name, norm1 = build_norm_layer(
-            self.normalize, width, postfix=1)
+            self.norm_cfg, width, postfix=1)
         self.norm2_name, norm2 = build_norm_layer(
-            self.normalize, width, postfix=2)
+            self.norm_cfg, width, postfix=2)
         self.norm3_name, norm3 = build_norm_layer(
-            self.normalize, self.planes * self.expansion, postfix=3)
+            self.norm_cfg, self.planes * self.expansion, postfix=3)
 
         self.conv1 = build_conv_layer(
             self.conv_cfg,
@@ -102,7 +102,7 @@ def make_res_layer(block,
                    style='pytorch',
                    with_cp=False,
                    conv_cfg=None,
-                   normalize=dict(type='BN'),
+                   norm_cfg=dict(type='BN'),
                    dcn=None):
     downsample = None
     if stride != 1 or inplanes != planes * block.expansion:
@@ -114,7 +114,7 @@ def make_res_layer(block,
                 kernel_size=1,
                 stride=stride,
                 bias=False),
-            build_norm_layer(normalize, planes * block.expansion)[1],
+            build_norm_layer(norm_cfg, planes * block.expansion)[1],
         )
 
     layers = []
@@ -130,7 +130,7 @@ def make_res_layer(block,
             style=style,
             with_cp=with_cp,
             conv_cfg=conv_cfg,
-            normalize=normalize,
+            norm_cfg=norm_cfg,
             dcn=dcn))
     inplanes = planes * block.expansion
     for i in range(1, blocks):
@@ -145,7 +145,7 @@ def make_res_layer(block,
                 style=style,
                 with_cp=with_cp,
                 conv_cfg=conv_cfg,
-                normalize=normalize,
+                norm_cfg=norm_cfg,
                 dcn=dcn))
 
     return nn.Sequential(*layers)
@@ -168,7 +168,7 @@ class ResNeXt(ResNet):
             the first 1x1 conv layer.
         frozen_stages (int): Stages to be frozen (all param fixed). -1 means
             not freezing any parameters.
-        normalize (dict): dictionary to construct and config norm layer.
+        norm_cfg (dict): Dictionary to construct and configure the norm layer.
         norm_eval (bool): Whether to set norm layers to eval mode, namely,
             freeze running stats (mean and var). Note: Effect on Batch Norm
             and its variants only.
@@ -208,7 +208,7 @@ class ResNeXt(ResNet):
                 style=self.style,
                 with_cp=self.with_cp,
                 conv_cfg=self.conv_cfg,
-                normalize=self.normalize,
+                norm_cfg=self.norm_cfg,
                 dcn=dcn)
             self.inplanes = planes * self.block.expansion
             layer_name = 'layer{}'.format(i + 1)
diff --git a/mmdet/models/bbox_heads/convfc_bbox_head.py b/mmdet/models/bbox_heads/convfc_bbox_head.py
index af2a5e3..c0e7154 100644
--- a/mmdet/models/bbox_heads/convfc_bbox_head.py
+++ b/mmdet/models/bbox_heads/convfc_bbox_head.py
@@ -25,7 +25,7 @@ class ConvFCBBoxHead(BBoxHead):
                  conv_out_channels=256,
                  fc_out_channels=1024,
                  conv_cfg=None,
-                 normalize=None,
+                 norm_cfg=None,
                  *args,
                  **kwargs):
         super(ConvFCBBoxHead, self).__init__(*args, **kwargs)
@@ -46,8 +46,7 @@ class ConvFCBBoxHead(BBoxHead):
         self.conv_out_channels = conv_out_channels
         self.fc_out_channels = fc_out_channels
         self.conv_cfg = conv_cfg
-        self.normalize = normalize
-        self.with_bias = normalize is None
+        self.norm_cfg = norm_cfg
 
         # add shared convs and fcs
         self.shared_convs, self.shared_fcs, last_layer_dim = \
@@ -104,8 +103,7 @@ class ConvFCBBoxHead(BBoxHead):
                         3,
                         padding=1,
                         conv_cfg=self.conv_cfg,
-                        normalize=self.normalize,
-                        bias=self.with_bias))
+                        norm_cfg=self.norm_cfg))
             last_layer_dim = self.conv_out_channels
         # add branch specific fc layers
         branch_fcs = nn.ModuleList()
diff --git a/mmdet/models/mask_heads/fcn_mask_head.py b/mmdet/models/mask_heads/fcn_mask_head.py
index 614497a..9274e7c 100644
--- a/mmdet/models/mask_heads/fcn_mask_head.py
+++ b/mmdet/models/mask_heads/fcn_mask_head.py
@@ -23,7 +23,7 @@ class FCNMaskHead(nn.Module):
                  num_classes=81,
                  class_agnostic=False,
                  conv_cfg=None,
-                 normalize=None):
+                 norm_cfg=None):
         super(FCNMaskHead, self).__init__()
         if upsample_method not in [None, 'deconv', 'nearest', 'bilinear']:
             raise ValueError(
@@ -39,8 +39,7 @@ class FCNMaskHead(nn.Module):
         self.num_classes = num_classes
         self.class_agnostic = class_agnostic
         self.conv_cfg = conv_cfg
-        self.normalize = normalize
-        self.with_bias = normalize is None
+        self.norm_cfg = norm_cfg
 
         self.convs = nn.ModuleList()
         for i in range(self.num_convs):
@@ -54,8 +53,7 @@ class FCNMaskHead(nn.Module):
                     self.conv_kernel_size,
                     padding=padding,
                     conv_cfg=conv_cfg,
-                    normalize=normalize,
-                    bias=self.with_bias))
+                    norm_cfg=norm_cfg))
         upsample_in_channels = (self.conv_out_channels
                                 if self.num_convs > 0 else in_channels)
         if self.upsample_method is None:
diff --git a/mmdet/models/mask_heads/fused_semantic_head.py b/mmdet/models/mask_heads/fused_semantic_head.py
index f24adf3..12b8302 100644
--- a/mmdet/models/mask_heads/fused_semantic_head.py
+++ b/mmdet/models/mask_heads/fused_semantic_head.py
@@ -31,7 +31,7 @@ class FusedSemanticHead(nn.Module):
                  ignore_label=255,
                  loss_weight=0.2,
                  conv_cfg=None,
-                 normalize=None):
+                 norm_cfg=None):
         super(FusedSemanticHead, self).__init__()
         self.num_ins = num_ins
         self.fusion_level = fusion_level
@@ -42,8 +42,7 @@ class FusedSemanticHead(nn.Module):
         self.ignore_label = ignore_label
         self.loss_weight = loss_weight
         self.conv_cfg = conv_cfg
-        self.normalize = normalize
-        self.with_bias = normalize is None
+        self.norm_cfg = norm_cfg
 
         self.lateral_convs = nn.ModuleList()
         for i in range(self.num_ins):
@@ -53,8 +52,7 @@ class FusedSemanticHead(nn.Module):
                     self.in_channels,
                     1,
                     conv_cfg=self.conv_cfg,
-                    normalize=self.normalize,
-                    bias=self.with_bias,
+                    norm_cfg=self.norm_cfg,
                     inplace=False))
 
         self.convs = nn.ModuleList()
@@ -67,15 +65,13 @@ class FusedSemanticHead(nn.Module):
                     3,
                     padding=1,
                     conv_cfg=self.conv_cfg,
-                    normalize=self.normalize,
-                    bias=self.with_bias))
+                    norm_cfg=self.norm_cfg))
         self.conv_embedding = ConvModule(
             conv_out_channels,
             conv_out_channels,
             1,
             conv_cfg=self.conv_cfg,
-            normalize=self.normalize,
-            bias=self.with_bias)
+            norm_cfg=self.norm_cfg)
         self.conv_logits = nn.Conv2d(conv_out_channels, self.num_classes, 1)
 
         self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label)
diff --git a/mmdet/models/mask_heads/htc_mask_head.py b/mmdet/models/mask_heads/htc_mask_head.py
index 21f3130..9ba3ed7 100644
--- a/mmdet/models/mask_heads/htc_mask_head.py
+++ b/mmdet/models/mask_heads/htc_mask_head.py
@@ -13,8 +13,7 @@ class HTCMaskHead(FCNMaskHead):
             self.conv_out_channels,
             1,
             conv_cfg=self.conv_cfg,
-            normalize=self.normalize,
-            bias=self.with_bias)
+            norm_cfg=self.norm_cfg)
 
     def init_weights(self):
         super(HTCMaskHead, self).init_weights()
diff --git a/mmdet/models/necks/fpn.py b/mmdet/models/necks/fpn.py
index 3e49fc4..7b33b69 100644
--- a/mmdet/models/necks/fpn.py
+++ b/mmdet/models/necks/fpn.py
@@ -18,7 +18,7 @@ class FPN(nn.Module):
                  add_extra_convs=False,
                  extra_convs_on_inputs=True,
                  conv_cfg=None,
-                 normalize=None,
+                 norm_cfg=None,
                  activation=None):
         super(FPN, self).__init__()
         assert isinstance(in_channels, list)
@@ -27,7 +27,6 @@ class FPN(nn.Module):
         self.num_ins = len(in_channels)
         self.num_outs = num_outs
         self.activation = activation
-        self.with_bias = normalize is None
 
         if end_level == -1:
             self.backbone_end_level = self.num_ins
@@ -51,8 +50,7 @@ class FPN(nn.Module):
                 out_channels,
                 1,
                 conv_cfg=conv_cfg,
-                normalize=normalize,
-                bias=self.with_bias,
+                norm_cfg=norm_cfg,
                 activation=self.activation,
                 inplace=False)
             fpn_conv = ConvModule(
@@ -61,8 +59,7 @@ class FPN(nn.Module):
                 3,
                 padding=1,
                 conv_cfg=conv_cfg,
-                normalize=normalize,
-                bias=self.with_bias,
+                norm_cfg=norm_cfg,
                 activation=self.activation,
                 inplace=False)
 
@@ -83,8 +80,8 @@ class FPN(nn.Module):
                     3,
                     stride=2,
                     padding=1,
-                    normalize=normalize,
-                    bias=self.with_bias,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
                     activation=self.activation,
                     inplace=False)
                 self.fpn_convs.append(extra_fpn_conv)
diff --git a/mmdet/models/shared_heads/res_layer.py b/mmdet/models/shared_heads/res_layer.py
index ea306e5..743c2ee 100644
--- a/mmdet/models/shared_heads/res_layer.py
+++ b/mmdet/models/shared_heads/res_layer.py
@@ -17,13 +17,13 @@ class ResLayer(nn.Module):
                  stride=2,
                  dilation=1,
                  style='pytorch',
-                 normalize=dict(type='BN', requires_grad=True),
+                 norm_cfg=dict(type='BN', requires_grad=True),
                  norm_eval=True,
                  with_cp=False,
                  dcn=None):
         super(ResLayer, self).__init__()
         self.norm_eval = norm_eval
-        self.normalize = normalize
+        self.norm_cfg = norm_cfg
         self.stage = stage
         block, stage_blocks = ResNet.arch_settings[depth]
         stage_block = stage_blocks[stage]
@@ -39,7 +39,7 @@ class ResLayer(nn.Module):
             dilation=dilation,
             style=style,
             with_cp=with_cp,
-            normalize=self.normalize,
+            norm_cfg=self.norm_cfg,
             dcn=dcn)
         self.add_module('layer{}'.format(stage + 1), res_layer)
 
diff --git a/mmdet/models/utils/conv_module.py b/mmdet/models/utils/conv_module.py
index b3bf9c7..329b235 100644
--- a/mmdet/models/utils/conv_module.py
+++ b/mmdet/models/utils/conv_module.py
@@ -42,6 +42,27 @@ def build_conv_layer(cfg, *args, **kwargs):
 
 
 class ConvModule(nn.Module):
+    """Conv-Norm-Activation block.
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int or tuple[int]): Same as nn.Conv2d.
+        padding (int or tuple[int]): Same as nn.Conv2d.
+        dilation (int or tuple[int]): Same as nn.Conv2d.
+        groups (int): Same as nn.Conv2d.
+        bias (bool or str): If specified as `auto`, it will be decided by
+            norm_cfg: bias is set to True if norm_cfg is None and to False
+            otherwise.
+        conv_cfg (dict): Config dict for convolution layer.
+        norm_cfg (dict): Config dict for normalization layer.
+        activation (str or None): Activation type, "relu" by default.
+        inplace (bool): Whether to use inplace mode for activation.
+        activate_last (bool): Whether to apply the activation layer last.
+            (Avoid relying on this flag; its behavior and API may change in
+            the future.)
+    """
 
     def __init__(self,
                  in_channels,
@@ -51,35 +72,42 @@ class ConvModule(nn.Module):
                  padding=0,
                  dilation=1,
                  groups=1,
-                 bias=True,
+                 bias='auto',
                  conv_cfg=None,
-                 normalize=None,
+                 norm_cfg=None,
                  activation='relu',
                  inplace=True,
                  activate_last=True):
         super(ConvModule, self).__init__()
         assert conv_cfg is None or isinstance(conv_cfg, dict)
-        assert normalize is None or isinstance(normalize, dict)
-        self.with_norm = normalize is not None
-        self.with_activatation = activation is not None
-        self.with_bias = bias
+        assert norm_cfg is None or isinstance(norm_cfg, dict)
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
         self.activation = activation
+        self.inplace = inplace
         self.activate_last = activate_last
 
+        self.with_norm = norm_cfg is not None
+        self.with_activatation = activation is not None
+        # if the conv layer is before a norm layer, bias is unnecessary.
+        if bias == 'auto':
+            bias = False if self.with_norm else True
+        self.with_bias = bias
+
         if self.with_norm and self.with_bias:
             warnings.warn('ConvModule has norm and bias at the same time')
 
-        self.conv = build_conv_layer(
-            conv_cfg,
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride,
-            padding,
-            dilation,
-            groups,
-            bias=bias)
-
+        # build convolution layer
+        self.conv = build_conv_layer(conv_cfg,
+                                     in_channels,
+                                     out_channels,
+                                     kernel_size,
+                                     stride=stride,
+                                     padding=padding,
+                                     dilation=dilation,
+                                     groups=groups,
+                                     bias=bias)
+        # export the attributes of self.conv to a higher level for convenience
         self.in_channels = self.conv.in_channels
         self.out_channels = self.conv.out_channels
         self.kernel_size = self.conv.kernel_size
@@ -90,17 +118,21 @@ class ConvModule(nn.Module):
         self.output_padding = self.conv.output_padding
         self.groups = self.conv.groups
 
+        # build normalization layers
         if self.with_norm:
             norm_channels = out_channels if self.activate_last else in_channels
-            self.norm_name, norm = build_norm_layer(normalize, norm_channels)
+            self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
             self.add_module(self.norm_name, norm)
 
+        # build activation layer
         if self.with_activatation:
-            assert activation in ['relu'], 'Only ReLU supported.'
+            if self.activation not in ['relu']:
+                raise ValueError('{} is currently not supported.'.format(
+                    self.activation))
             if self.activation == 'relu':
                 self.activate = nn.ReLU(inplace=inplace)
 
-        # Default using msra init
+        # Use msra init by default
         self.init_weights()
 
     @property
@@ -121,6 +153,7 @@ class ConvModule(nn.Module):
             if activate and self.with_activatation:
                 x = self.activate(x)
         else:
+            # WARN: this may be removed or modified
             if norm and self.with_norm:
                 x = self.norm(x)
             if activate and self.with_activatation:
-- 
GitLab