diff --git a/configs/cascade_mask_rcnn_r50_c4_1x.py b/configs/cascade_mask_rcnn_r50_c4_1x.py
index d60a165c10341289234453e49f1fbc753e26db40..a31e5187184fa9e2e9e4761d16e7fa556bbf074a 100644
--- a/configs/cascade_mask_rcnn_r50_c4_1x.py
+++ b/configs/cascade_mask_rcnn_r50_c4_1x.py
@@ -1,4 +1,5 @@
 # model settings
+norm_cfg = dict(type='BN', requires_grad=False)
 model = dict(
     type='CascadeRCNN',
     num_stages=3,
@@ -11,7 +12,7 @@ model = dict(
         dilations=(1, 1, 1),
         out_indices=(2, ),
         frozen_stages=1,
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True,
         style='caffe'),
     shared_head=dict(
@@ -21,7 +22,7 @@ model = dict(
         stride=2,
         dilation=1,
         style='caffe',
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True),
     rpn_head=dict(
         type='RPNHead',
diff --git a/configs/cascade_rcnn_r50_c4_1x.py b/configs/cascade_rcnn_r50_c4_1x.py
index d2ea3d0457a1a2813677861dd4f315ff528830e0..8224f2f1e0f96a1a2ff6c2dff99562ee1bf3a299 100644
--- a/configs/cascade_rcnn_r50_c4_1x.py
+++ b/configs/cascade_rcnn_r50_c4_1x.py
@@ -1,4 +1,5 @@
 # model settings
+norm_cfg = dict(type='BN', requires_grad=False)
 model = dict(
     type='CascadeRCNN',
     num_stages=3,
@@ -11,7 +12,7 @@ model = dict(
         dilations=(1, 1, 1),
         out_indices=(2, ),
         frozen_stages=1,
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True,
         style='caffe'),
     shared_head=dict(
@@ -21,7 +22,7 @@ model = dict(
         stride=2,
         dilation=1,
         style='caffe',
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True),
     rpn_head=dict(
         type='RPNHead',
diff --git a/configs/fast_rcnn_r50_c4_1x.py b/configs/fast_rcnn_r50_c4_1x.py
index 052e5780cf8e9b632adba3034054a052dec3c30a..a204d13d25e29a90670f1d9572974c2a846187ce 100644
--- a/configs/fast_rcnn_r50_c4_1x.py
+++ b/configs/fast_rcnn_r50_c4_1x.py
@@ -1,4 +1,5 @@
 # model settings
+norm_cfg = dict(type='BN', requires_grad=False)
 model = dict(
     type='FastRCNN',
     pretrained='open-mmlab://resnet50_caffe',
@@ -10,7 +11,7 @@ model = dict(
         dilations=(1, 1, 1),
         out_indices=(2, ),
         frozen_stages=1,
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True,
         style='caffe'),
     shared_head=dict(
@@ -20,7 +21,7 @@ model = dict(
         stride=2,
         dilation=1,
         style='caffe',
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True),
     bbox_roi_extractor=dict(
         type='SingleRoIExtractor',
diff --git a/configs/faster_rcnn_r50_c4_1x.py b/configs/faster_rcnn_r50_c4_1x.py
index 01183203bfbebf3a69d8c4788dbe48d0f9136ab3..61ae19ddde4f150644bdd71c13edff08dc6de0e1 100644
--- a/configs/faster_rcnn_r50_c4_1x.py
+++ b/configs/faster_rcnn_r50_c4_1x.py
@@ -1,4 +1,5 @@
 # model settings
+norm_cfg = dict(type='BN', requires_grad=False)
 model = dict(
     type='FasterRCNN',
     pretrained='open-mmlab://resnet50_caffe',
@@ -10,7 +11,7 @@ model = dict(
         dilations=(1, 1, 1),
         out_indices=(2, ),
         frozen_stages=1,
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True,
         style='caffe'),
     shared_head=dict(
@@ -20,7 +21,7 @@ model = dict(
         stride=2,
         dilation=1,
         style='caffe',
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True),
     rpn_head=dict(
         type='RPNHead',
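The C4 config diffs above all make the same two-part change: the frozen-BN dict that used to be repeated inline is declared once at the top of the file as `norm_cfg`, and every component that previously took `normalize=dict(...)` now references the shared variable. A minimal sketch of the resulting config shape (the component entries are abbreviated to the norm-related keys and the values are illustrative, not a complete config):

```python
# Declared once per config file: frozen BatchNorm. requires_grad=False
# freezes the affine weight/bias, while norm_eval=True (below) additionally
# keeps the running statistics fixed during training.
norm_cfg = dict(type='BN', requires_grad=False)

model = dict(
    type='FasterRCNN',
    backbone=dict(
        type='ResNet',
        depth=50,
        norm_cfg=norm_cfg,  # was: normalize=dict(type='BN', requires_grad=False)
        norm_eval=True,
        style='caffe'),
    shared_head=dict(
        type='ResLayer',
        norm_cfg=norm_cfg,  # the same dict is reused instead of re-declared
        norm_eval=True))
```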
diff --git a/configs/gn+ws/faster_rcnn_r50_fpn_gn_ws_1x.py b/configs/gn+ws/faster_rcnn_r50_fpn_gn_ws_1x.py
index 067703db1473c30562e58547e082bf26a8153660..ef635839bb026a747154622285a67cef6a37e600 100644
--- a/configs/gn+ws/faster_rcnn_r50_fpn_gn_ws_1x.py
+++ b/configs/gn+ws/faster_rcnn_r50_fpn_gn_ws_1x.py
@@ -1,6 +1,6 @@
 # model settings
 conv_cfg = dict(type='ConvWS')
-normalize = dict(type='GN', num_groups=32, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
 model = dict(
     type='FasterRCNN',
     pretrained='open-mmlab://jhu/resnet50_gn_ws',
@@ -12,14 +12,14 @@ model = dict(
         frozen_stages=1,
         style='pytorch',
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         num_outs=5,
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     rpn_head=dict(
         type='RPNHead',
         in_channels=256,
@@ -48,7 +48,7 @@ model = dict(
         target_stds=[0.1, 0.1, 0.2, 0.2],
         reg_class_agnostic=False,
         conv_cfg=conv_cfg,
-        normalize=normalize))
+        norm_cfg=norm_cfg))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
diff --git a/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_20_23_24e.py b/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_20_23_24e.py
index 12a0c2b7aa006c42df631959cb55ff14bc7f88fd..5a79bfd6c4ac4e21d74ebc12c3aaf1b5194775da 100644
--- a/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_20_23_24e.py
+++ b/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_20_23_24e.py
@@ -1,6 +1,6 @@
 # model settings
 conv_cfg = dict(type='ConvWS')
-normalize = dict(type='GN', num_groups=32, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
 model = dict(
     type='MaskRCNN',
     pretrained='open-mmlab://jhu/resnet50_gn_ws',
@@ -12,14 +12,14 @@ model = dict(
         frozen_stages=1,
         style='pytorch',
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         num_outs=5,
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     rpn_head=dict(
         type='RPNHead',
         in_channels=256,
@@ -48,7 +48,7 @@ model = dict(
         target_stds=[0.1, 0.1, 0.2, 0.2],
         reg_class_agnostic=False,
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     mask_roi_extractor=dict(
         type='SingleRoIExtractor',
         roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
@@ -61,7 +61,7 @@ model = dict(
         conv_out_channels=256,
         num_classes=81,
         conv_cfg=conv_cfg,
-        normalize=normalize))
+        norm_cfg=norm_cfg))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
diff --git a/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_2x.py b/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_2x.py
index 7ae7740e27c687e0368aeddb92b89d24ea314f64..7294aee69286503898d9747d724d36d975380a05 100644
--- a/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_2x.py
+++ b/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_2x.py
@@ -1,6 +1,6 @@
 # model settings
 conv_cfg = dict(type='ConvWS')
-normalize = dict(type='GN', num_groups=32, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
 model = dict(
     type='MaskRCNN',
     pretrained='open-mmlab://jhu/resnet50_gn_ws',
@@ -12,14 +12,14 @@ model = dict(
         frozen_stages=1,
         style='pytorch',
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         num_outs=5,
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     rpn_head=dict(
         type='RPNHead',
         in_channels=256,
@@ -48,7 +48,7 @@ model = dict(
         target_stds=[0.1, 0.1, 0.2, 0.2],
         reg_class_agnostic=False,
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     mask_roi_extractor=dict(
         type='SingleRoIExtractor',
         roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
@@ -61,7 +61,7 @@ model = dict(
         conv_out_channels=256,
         num_classes=81,
         conv_cfg=conv_cfg,
-        normalize=normalize))
+        norm_cfg=norm_cfg))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
diff --git a/configs/gn+ws/mask_rcnn_x101_32x4d_fpn_gn_ws_2x.py b/configs/gn+ws/mask_rcnn_x101_32x4d_fpn_gn_ws_2x.py
index 1d5fb980c951a24d11f6fc4a6375365f2680b4a6..4ed83b251be8b6fcb0ad0eb55043e913b5c9b4c8 100644
--- a/configs/gn+ws/mask_rcnn_x101_32x4d_fpn_gn_ws_2x.py
+++ b/configs/gn+ws/mask_rcnn_x101_32x4d_fpn_gn_ws_2x.py
@@ -1,6 +1,6 @@
 # model settings
 conv_cfg = dict(type='ConvWS')
-normalize = dict(type='GN', num_groups=32, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
 model = dict(
     type='MaskRCNN',
     pretrained='open-mmlab://jhu/resnext101_32x4d_gn_ws',
@@ -14,14 +14,14 @@ model = dict(
         frozen_stages=1,
         style='pytorch',
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         num_outs=5,
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     rpn_head=dict(
         type='RPNHead',
         in_channels=256,
@@ -50,7 +50,7 @@ model = dict(
         target_stds=[0.1, 0.1, 0.2, 0.2],
         reg_class_agnostic=False,
         conv_cfg=conv_cfg,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     mask_roi_extractor=dict(
         type='SingleRoIExtractor',
         roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
@@ -63,7 +63,7 @@ model = dict(
         conv_out_channels=256,
         num_classes=81,
         conv_cfg=conv_cfg,
-        normalize=normalize))
+        norm_cfg=norm_cfg))
 # model training and testing settings
 train_cfg = dict(
     rpn=dict(
diff --git a/configs/gn/mask_rcnn_r101_fpn_gn_2x.py b/configs/gn/mask_rcnn_r101_fpn_gn_2x.py
index 97b0c95c7d1b494009a188d309c90a9342b3eae9..3f61dc402a534b59010b1b5c0acfd1fe3c557605 100644
--- a/configs/gn/mask_rcnn_r101_fpn_gn_2x.py
+++ b/configs/gn/mask_rcnn_r101_fpn_gn_2x.py
@@ -1,5 +1,5 @@
 # model settings
-normalize = dict(type='GN', num_groups=32, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
 
 model = dict(
     type='MaskRCNN',
@@ -11,13 +11,13 @@ model = dict(
         out_indices=(0, 1, 2, 3),
         frozen_stages=1,
         style='pytorch',
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         num_outs=5,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     rpn_head=dict(
         type='RPNHead',
         in_channels=256,
@@ -45,7 +45,7 @@ model = dict(
         target_means=[0., 0., 0., 0.],
         target_stds=[0.1, 0.1, 0.2, 0.2],
         reg_class_agnostic=False,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     mask_roi_extractor=dict(
         type='SingleRoIExtractor',
         roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
@@ -57,7 +57,7 @@ model = dict(
         in_channels=256,
         conv_out_channels=256,
         num_classes=81,
-        normalize=normalize))
+        norm_cfg=norm_cfg))
 
 # model training and testing settings
 train_cfg = dict(
diff --git a/configs/gn/mask_rcnn_r50_fpn_gn_2x.py b/configs/gn/mask_rcnn_r50_fpn_gn_2x.py
index edbd7d8349c72d4f9ce98780a7ac9f891be9a868..165c4aa06fc018a6a4fdf489351090e8c1bbf04a 100644
--- a/configs/gn/mask_rcnn_r50_fpn_gn_2x.py
+++ b/configs/gn/mask_rcnn_r50_fpn_gn_2x.py
@@ -1,5 +1,5 @@
 # model settings
-normalize = dict(type='GN', num_groups=32, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
 
 model = dict(
     type='MaskRCNN',
@@ -11,13 +11,13 @@ model = dict(
         out_indices=(0, 1, 2, 3),
         frozen_stages=1,
         style='pytorch',
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         num_outs=5,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     rpn_head=dict(
         type='RPNHead',
         in_channels=256,
@@ -45,7 +45,7 @@ model = dict(
         target_means=[0., 0., 0., 0.],
         target_stds=[0.1, 0.1, 0.2, 0.2],
         reg_class_agnostic=False,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     mask_roi_extractor=dict(
         type='SingleRoIExtractor',
         roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
@@ -57,7 +57,7 @@ model = dict(
         in_channels=256,
         conv_out_channels=256,
         num_classes=81,
-        normalize=normalize))
+        norm_cfg=norm_cfg))
 
 # model training and testing settings
 train_cfg = dict(
diff --git a/configs/gn/mask_rcnn_r50_fpn_gn_contrib_2x.py b/configs/gn/mask_rcnn_r50_fpn_gn_contrib_2x.py
index 5c63992f0f932547e844278f90c4e873c3a4b3f9..00760fb031eb6a034b285ec2e8abe18b0f5527cc 100644
--- a/configs/gn/mask_rcnn_r50_fpn_gn_contrib_2x.py
+++ b/configs/gn/mask_rcnn_r50_fpn_gn_contrib_2x.py
@@ -1,5 +1,5 @@
 # model settings
-normalize = dict(type='GN', num_groups=32, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
 
 model = dict(
     type='MaskRCNN',
@@ -11,13 +11,13 @@ model = dict(
         out_indices=(0, 1, 2, 3),
         frozen_stages=1,
         style='pytorch',
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     neck=dict(
         type='FPN',
         in_channels=[256, 512, 1024, 2048],
         out_channels=256,
         num_outs=5,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     rpn_head=dict(
         type='RPNHead',
         in_channels=256,
@@ -45,7 +45,7 @@ model = dict(
         target_means=[0., 0., 0., 0.],
         target_stds=[0.1, 0.1, 0.2, 0.2],
         reg_class_agnostic=False,
-        normalize=normalize),
+        norm_cfg=norm_cfg),
     mask_roi_extractor=dict(
         type='SingleRoIExtractor',
         roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
@@ -57,7 +57,7 @@ model = dict(
         in_channels=256,
         conv_out_channels=256,
         num_classes=81,
-        normalize=normalize))
+        norm_cfg=norm_cfg))
 
 # model training and testing settings
 train_cfg = dict(
diff --git a/configs/mask_rcnn_r50_c4_1x.py b/configs/mask_rcnn_r50_c4_1x.py
index 63884a4ed6746eb2e625787229ec607eeec86d26..e7584be88e6eb234ac2f25656de6b8a6918b999b 100644
--- a/configs/mask_rcnn_r50_c4_1x.py
+++ b/configs/mask_rcnn_r50_c4_1x.py
@@ -1,4 +1,5 @@
 # model settings
+norm_cfg = dict(type='BN', requires_grad=False)
 model = dict(
     type='MaskRCNN',
     pretrained='open-mmlab://resnet50_caffe',
@@ -10,7 +11,7 @@ model = dict(
         dilations=(1, 1, 1),
         out_indices=(2, ),
         frozen_stages=1,
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True,
         style='caffe'),
     shared_head=dict(
@@ -20,7 +21,7 @@ model = dict(
         stride=2,
         dilation=1,
         style='caffe',
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=norm_cfg,
         norm_eval=True),
     rpn_head=dict(
         type='RPNHead',
diff --git a/configs/rpn_r50_c4_1x.py b/configs/rpn_r50_c4_1x.py
index bc3d2a8f1f64b90f211bde8952ce7e631c4bf55e..ac9ad836956aa3968a84654285a02eb15f8832a2 100644
--- a/configs/rpn_r50_c4_1x.py
+++ b/configs/rpn_r50_c4_1x.py
@@ -10,7 +10,7 @@ model = dict(
         dilations=(1, 1, 1),
         out_indices=(2, ),
         frozen_stages=1,
-        normalize=dict(type='BN', requires_grad=False),
+        norm_cfg=dict(type='BN', requires_grad=False),
         norm_eval=True,
         style='caffe'),
     neck=None,
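The remaining diffs make the matching change on the module side: every backbone, neck, and head stores the dict as `self.norm_cfg` and forwards it either to `build_norm_layer` (for bare norm layers) or to `ConvModule` (for conv-norm-act blocks). For orientation, here is a simplified, illustrative re-implementation of what `build_norm_layer` does with such a dict; the real helper lives in `mmdet/models/utils/norm.py` at this point in history and handles more layer types and options:

```python
import torch.nn as nn

def build_norm_layer_sketch(cfg, num_features, postfix=''):
    """Illustrative stand-in for mmdet's build_norm_layer (BN/GN only)."""
    cfg = cfg.copy()
    layer_type = cfg.pop('type')
    requires_grad = cfg.pop('requires_grad', True)
    if layer_type == 'BN':
        name = 'bn' + str(postfix)  # e.g. self.norm1_name == 'bn1'
        layer = nn.BatchNorm2d(num_features)
    elif layer_type == 'GN':
        name = 'gn' + str(postfix)
        layer = nn.GroupNorm(cfg['num_groups'], num_features)
    else:
        raise KeyError('Unrecognized norm type {}'.format(layer_type))
    for param in layer.parameters():
        # requires_grad=False yields the frozen affine parameters used by
        # the caffe-style C4 configs above.
        param.requires_grad = requires_grad
    return name, layer
```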
diff --git a/mmdet/models/anchor_heads/retina_head.py b/mmdet/models/anchor_heads/retina_head.py
index 27701f46e9ead7865d294f75341054092efa19a2..815e8f0f91601588e31b9ce16a821c4e6aa57873 100644
--- a/mmdet/models/anchor_heads/retina_head.py
+++ b/mmdet/models/anchor_heads/retina_head.py
@@ -17,13 +17,13 @@ class RetinaHead(AnchorHead):
                  octave_base_scale=4,
                  scales_per_octave=3,
                  conv_cfg=None,
-                 normalize=None,
+                 norm_cfg=None,
                  **kwargs):
         self.stacked_convs = stacked_convs
         self.octave_base_scale = octave_base_scale
         self.scales_per_octave = scales_per_octave
         self.conv_cfg = conv_cfg
-        self.normalize = normalize
+        self.norm_cfg = norm_cfg
         octave_scales = np.array(
             [2**(i / scales_per_octave) for i in range(scales_per_octave)])
         anchor_scales = octave_scales * octave_base_scale
@@ -49,8 +49,7 @@ class RetinaHead(AnchorHead):
                     stride=1,
                     padding=1,
                     conv_cfg=self.conv_cfg,
-                    normalize=self.normalize,
-                    bias=self.normalize is None))
+                    norm_cfg=self.norm_cfg))
             self.reg_convs.append(
                 ConvModule(
                     chn,
@@ -59,8 +58,7 @@ class RetinaHead(AnchorHead):
                     stride=1,
                     padding=1,
                     conv_cfg=self.conv_cfg,
-                    normalize=self.normalize,
-                    bias=self.normalize is None))
+                    norm_cfg=self.norm_cfg))
         self.retina_cls = nn.Conv2d(
             self.feat_channels,
             self.num_anchors * self.cls_out_channels,
diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py
index c9daeef68c81213c2809d0c985e024a2953f2cc3..037fbb14d0bb7d9d46069dba7685194c51b6d713 100644
--- a/mmdet/models/backbones/resnet.py
+++ b/mmdet/models/backbones/resnet.py
@@ -23,13 +23,13 @@ class BasicBlock(nn.Module):
                  style='pytorch',
                  with_cp=False,
                  conv_cfg=None,
-                 normalize=dict(type='BN'),
+                 norm_cfg=dict(type='BN'),
                  dcn=None):
         super(BasicBlock, self).__init__()
         assert dcn is None, "Not implemented yet."
 
-        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
-        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
+        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
 
         self.conv1 = build_conv_layer(
             conv_cfg,
@@ -95,7 +95,7 @@ class Bottleneck(nn.Module):
                  style='pytorch',
                  with_cp=False,
                  conv_cfg=None,
-                 normalize=dict(type='BN'),
+                 norm_cfg=dict(type='BN'),
                  dcn=None):
         """Bottleneck block for ResNet.
         If style is "pytorch", the stride-two layer is the 3x3 conv layer,
@@ -106,21 +106,26 @@ class Bottleneck(nn.Module):
         assert dcn is None or isinstance(dcn, dict)
         self.inplanes = inplanes
         self.planes = planes
+        self.stride = stride
+        self.dilation = dilation
+        self.downsample = downsample
+        self.style = style
+        self.with_cp = with_cp
         self.conv_cfg = conv_cfg
-        self.normalize = normalize
+        self.norm_cfg = norm_cfg
         self.dcn = dcn
         self.with_dcn = dcn is not None
-        if style == 'pytorch':
+        if self.style == 'pytorch':
             self.conv1_stride = 1
             self.conv2_stride = stride
         else:
             self.conv1_stride = stride
             self.conv2_stride = 1
 
-        self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
-        self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
+        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
         self.norm3_name, norm3 = build_norm_layer(
-            normalize, planes * self.expansion, postfix=3)
+            norm_cfg, planes * self.expansion, postfix=3)
 
         self.conv1 = build_conv_layer(
             conv_cfg,
@@ -180,11 +185,6 @@ class Bottleneck(nn.Module):
         self.add_module(self.norm3_name, norm3)
 
         self.relu = nn.ReLU(inplace=True)
-        self.downsample = downsample
-        self.stride = stride
-        self.dilation = dilation
-        self.with_cp = with_cp
-        self.normalize = normalize
 
     @property
     def norm1(self):
@@ -249,7 +249,7 @@ def make_res_layer(block,
                   style='pytorch',
                   with_cp=False,
                   conv_cfg=None,
-                  normalize=dict(type='BN'),
+                  norm_cfg=dict(type='BN'),
                   dcn=None):
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
@@ -261,7 +261,7 @@ def make_res_layer(block,
                kernel_size=1,
                stride=stride,
                bias=False),
-            build_norm_layer(normalize, planes * block.expansion)[1],
+            build_norm_layer(norm_cfg, planes * block.expansion)[1],
        )
 
    layers = []
@@ -275,7 +275,7 @@ def make_res_layer(block,
            style=style,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
-            normalize=normalize,
+            norm_cfg=norm_cfg,
            dcn=dcn))
    inplanes = planes * block.expansion
    for i in range(1, blocks):
@@ -288,7 +288,7 @@ def make_res_layer(block,
                style=style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
-                normalize=normalize,
+                norm_cfg=norm_cfg,
                dcn=dcn))
 
    return nn.Sequential(*layers)
@@ -309,7 +309,7 @@ class ResNet(nn.Module):
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
-        normalize (dict): dictionary to construct and config norm layer.
+        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
@@ -336,7 +336,7 @@ class ResNet(nn.Module):
                 style='pytorch',
                 frozen_stages=-1,
                 conv_cfg=None,
-                 normalize=dict(type='BN', requires_grad=True),
+                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=True,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
@@ -356,7 +356,7 @@ class ResNet(nn.Module):
        self.style = style
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
-        self.normalize = normalize
+        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.dcn = dcn
@@ -386,7 +386,7 @@ class ResNet(nn.Module):
                style=self.style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
-                normalize=normalize,
+                norm_cfg=norm_cfg,
                dcn=dcn)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
@@ -411,8 +411,7 @@ class ResNet(nn.Module):
            stride=2,
            padding=3,
            bias=False)
-        self.norm1_name, norm1 = build_norm_layer(
-            self.normalize, 64, postfix=1)
+        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
diff --git a/mmdet/models/backbones/resnext.py b/mmdet/models/backbones/resnext.py
index 3cdb1cb26e49bf011b6cc7a47f3a4ef02427b325..c869a02ebc93e0a40d66ef4c13d6aab7c0a5bb4f 100644
--- a/mmdet/models/backbones/resnext.py
+++ b/mmdet/models/backbones/resnext.py
@@ -24,11 +24,11 @@ class Bottleneck(_Bottleneck):
            width = math.floor(self.planes * (base_width / 64)) * groups
 
        self.norm1_name, norm1 = build_norm_layer(
-            self.normalize, width, postfix=1)
+            self.norm_cfg, width, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(
-            self.normalize, width, postfix=2)
+            self.norm_cfg, width, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
-            self.normalize, self.planes * self.expansion, postfix=3)
+            self.norm_cfg, self.planes * self.expansion, postfix=3)
 
        self.conv1 = build_conv_layer(
            self.conv_cfg,
@@ -102,7 +102,7 @@ def make_res_layer(block,
                   style='pytorch',
                   with_cp=False,
                   conv_cfg=None,
-                   normalize=dict(type='BN'),
+                   norm_cfg=dict(type='BN'),
                   dcn=None):
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
@@ -114,7 +114,7 @@ def make_res_layer(block,
                kernel_size=1,
                stride=stride,
                bias=False),
-            build_norm_layer(normalize, planes * block.expansion)[1],
+            build_norm_layer(norm_cfg, planes * block.expansion)[1],
        )
 
    layers = []
@@ -130,7 +130,7 @@ def make_res_layer(block,
            style=style,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
-            normalize=normalize,
+            norm_cfg=norm_cfg,
            dcn=dcn))
    inplanes = planes * block.expansion
    for i in range(1, blocks):
@@ -145,7 +145,7 @@ def make_res_layer(block,
                style=style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
-                normalize=normalize,
+                norm_cfg=norm_cfg,
                dcn=dcn))
 
    return nn.Sequential(*layers)
@@ -168,7 +168,7 @@ class ResNeXt(ResNet):
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed).
            -1 means not freezing any parameters.
-        normalize (dict): dictionary to construct and config norm layer.
+        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
@@ -208,7 +208,7 @@ class ResNeXt(ResNet):
                style=self.style,
                with_cp=self.with_cp,
                conv_cfg=self.conv_cfg,
-                normalize=self.normalize,
+                norm_cfg=self.norm_cfg,
                dcn=dcn)
            self.inplanes = planes * self.block.expansion
            layer_name = 'layer{}'.format(i + 1)
diff --git a/mmdet/models/bbox_heads/convfc_bbox_head.py b/mmdet/models/bbox_heads/convfc_bbox_head.py
index af2a5e34907b9a2c7ba865f7916fd592b5b0d54b..c0e715415f8b8eb5f519dd25e22d9df757d299c8 100644
--- a/mmdet/models/bbox_heads/convfc_bbox_head.py
+++ b/mmdet/models/bbox_heads/convfc_bbox_head.py
@@ -25,7 +25,7 @@ class ConvFCBBoxHead(BBoxHead):
                 conv_out_channels=256,
                 fc_out_channels=1024,
                 conv_cfg=None,
-                 normalize=None,
+                 norm_cfg=None,
                 *args,
                 **kwargs):
        super(ConvFCBBoxHead, self).__init__(*args, **kwargs)
@@ -46,8 +46,7 @@ class ConvFCBBoxHead(BBoxHead):
        self.conv_out_channels = conv_out_channels
        self.fc_out_channels = fc_out_channels
        self.conv_cfg = conv_cfg
-        self.normalize = normalize
-        self.with_bias = normalize is None
+        self.norm_cfg = norm_cfg
 
        # add shared convs and fcs
        self.shared_convs, self.shared_fcs, last_layer_dim = \
@@ -104,8 +103,7 @@ class ConvFCBBoxHead(BBoxHead):
                    3,
                    padding=1,
                    conv_cfg=self.conv_cfg,
-                    normalize=self.normalize,
-                    bias=self.with_bias))
+                    norm_cfg=self.norm_cfg))
            last_layer_dim = self.conv_out_channels
        # add branch specific fc layers
        branch_fcs = nn.ModuleList()
diff --git a/mmdet/models/mask_heads/fcn_mask_head.py b/mmdet/models/mask_heads/fcn_mask_head.py
index 614497a8e288bc8afe8d5d96c67d704e620f0e91..9274e7c8cea60d3c3d5a606b5e9f3207aeacb7a5 100644
--- a/mmdet/models/mask_heads/fcn_mask_head.py
+++ b/mmdet/models/mask_heads/fcn_mask_head.py
@@ -23,7 +23,7 @@ class FCNMaskHead(nn.Module):
                 num_classes=81,
                 class_agnostic=False,
                 conv_cfg=None,
-                 normalize=None):
+                 norm_cfg=None):
        super(FCNMaskHead, self).__init__()
        if upsample_method not in [None, 'deconv', 'nearest', 'bilinear']:
            raise ValueError(
@@ -39,8 +39,7 @@ class FCNMaskHead(nn.Module):
        self.num_classes = num_classes
        self.class_agnostic = class_agnostic
        self.conv_cfg = conv_cfg
-        self.normalize = normalize
-        self.with_bias = normalize is None
+        self.norm_cfg = norm_cfg
 
        self.convs = nn.ModuleList()
        for i in range(self.num_convs):
@@ -54,8 +53,7 @@ class FCNMaskHead(nn.Module):
                    self.conv_kernel_size,
                    padding=padding,
                    conv_cfg=conv_cfg,
-                    normalize=normalize,
-                    bias=self.with_bias))
+                    norm_cfg=norm_cfg))
        upsample_in_channels = (self.conv_out_channels
                                if self.num_convs > 0 else in_channels)
        if self.upsample_method is None:
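A second pattern runs through these head and neck diffs: the hand-rolled `self.with_bias = normalize is None` bookkeeping and the explicit `bias=` argument both disappear, because the reworked `ConvModule` (at the end of this diff) defaults to `bias='auto'` and derives the same value internally. Roughly, at a call site (the import path is an assumption based on this tree, where `ConvModule` is re-exported from `mmdet.models.utils`):

```python
from mmdet.models.utils import ConvModule

# Before this diff, every caller computed the bias by hand:
#     self.with_bias = normalize is None  # bias is redundant before a norm
#     ConvModule(..., normalize=normalize, bias=self.with_bias)
#
# After it, ConvModule's new default bias='auto' applies the same rule
# internally, so a head's conv branch reduces to:
conv = ConvModule(
    256, 256, 3, padding=1,
    norm_cfg=dict(type='GN', num_groups=32, requires_grad=True))
assert conv.with_bias is False  # inferred: a norm layer follows the conv
```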
diff --git a/mmdet/models/mask_heads/fused_semantic_head.py b/mmdet/models/mask_heads/fused_semantic_head.py
index f24adf357eb0d8a18ab215066a33fb8f264343dd..12b8302ab11507234e3a78146f77a0ba67f4e242 100644
--- a/mmdet/models/mask_heads/fused_semantic_head.py
+++ b/mmdet/models/mask_heads/fused_semantic_head.py
@@ -31,7 +31,7 @@ class FusedSemanticHead(nn.Module):
                 ignore_label=255,
                 loss_weight=0.2,
                 conv_cfg=None,
-                 normalize=None):
+                 norm_cfg=None):
        super(FusedSemanticHead, self).__init__()
        self.num_ins = num_ins
        self.fusion_level = fusion_level
@@ -42,8 +42,7 @@ class FusedSemanticHead(nn.Module):
        self.ignore_label = ignore_label
        self.loss_weight = loss_weight
        self.conv_cfg = conv_cfg
-        self.normalize = normalize
-        self.with_bias = normalize is None
+        self.norm_cfg = norm_cfg
 
        self.lateral_convs = nn.ModuleList()
        for i in range(self.num_ins):
@@ -53,8 +52,7 @@ class FusedSemanticHead(nn.Module):
                    self.in_channels,
                    1,
                    conv_cfg=self.conv_cfg,
-                    normalize=self.normalize,
-                    bias=self.with_bias,
+                    norm_cfg=self.norm_cfg,
                    inplace=False))
 
        self.convs = nn.ModuleList()
@@ -67,15 +65,13 @@ class FusedSemanticHead(nn.Module):
                    3,
                    padding=1,
                    conv_cfg=self.conv_cfg,
-                    normalize=self.normalize,
-                    bias=self.with_bias))
+                    norm_cfg=self.norm_cfg))
        self.conv_embedding = ConvModule(
            conv_out_channels,
            conv_out_channels,
            1,
            conv_cfg=self.conv_cfg,
-            normalize=self.normalize,
-            bias=self.with_bias)
+            norm_cfg=self.norm_cfg)
        self.conv_logits = nn.Conv2d(conv_out_channels, self.num_classes, 1)
 
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label)
diff --git a/mmdet/models/mask_heads/htc_mask_head.py b/mmdet/models/mask_heads/htc_mask_head.py
index 21f3130494c3e1ed145c0d42d52565f837d4e18f..9ba3ed72298a36ed63a185f581de889b72324f8d 100644
--- a/mmdet/models/mask_heads/htc_mask_head.py
+++ b/mmdet/models/mask_heads/htc_mask_head.py
@@ -13,8 +13,7 @@ class HTCMaskHead(FCNMaskHead):
            self.conv_out_channels,
            1,
            conv_cfg=self.conv_cfg,
-            normalize=self.normalize,
-            bias=self.with_bias)
+            norm_cfg=self.norm_cfg)
 
    def init_weights(self):
        super(HTCMaskHead, self).init_weights()
diff --git a/mmdet/models/necks/fpn.py b/mmdet/models/necks/fpn.py
index 3e49fc4943994dd110f3f63c98b550475c228b44..7b33b69de7de8432c7b1e40ec737b1cdc63a28e8 100644
--- a/mmdet/models/necks/fpn.py
+++ b/mmdet/models/necks/fpn.py
@@ -18,7 +18,7 @@ class FPN(nn.Module):
                 add_extra_convs=False,
                 extra_convs_on_inputs=True,
                 conv_cfg=None,
-                 normalize=None,
+                 norm_cfg=None,
                 activation=None):
        super(FPN, self).__init__()
        assert isinstance(in_channels, list)
@@ -27,7 +27,6 @@ class FPN(nn.Module):
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.activation = activation
-        self.with_bias = normalize is None
 
        if end_level == -1:
            self.backbone_end_level = self.num_ins
@@ -51,8 +50,7 @@ class FPN(nn.Module):
                out_channels,
                1,
                conv_cfg=conv_cfg,
-                normalize=normalize,
-                bias=self.with_bias,
+                norm_cfg=norm_cfg,
                activation=self.activation,
                inplace=False)
            fpn_conv = ConvModule(
@@ -61,8 +59,7 @@ class FPN(nn.Module):
                3,
                padding=1,
                conv_cfg=conv_cfg,
-                normalize=normalize,
-                bias=self.with_bias,
+                norm_cfg=norm_cfg,
                activation=self.activation,
                inplace=False)
 
@@ -83,8 +80,8 @@ class FPN(nn.Module):
                    3,
                    stride=2,
                    padding=1,
-                    normalize=normalize,
-                    bias=self.with_bias,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
                    activation=self.activation,
                    inplace=False)
                self.fpn_convs.append(extra_fpn_conv)
diff --git a/mmdet/models/shared_heads/res_layer.py b/mmdet/models/shared_heads/res_layer.py
index ea306e54559c42d768693357caddd649d9fa1980..743c2eefe60cf79008284fae218a4e41c8630521 100644
--- a/mmdet/models/shared_heads/res_layer.py
+++ b/mmdet/models/shared_heads/res_layer.py
@@ -17,13 +17,13 @@ class ResLayer(nn.Module):
                 stride=2,
                 dilation=1,
                 style='pytorch',
-                 normalize=dict(type='BN', requires_grad=True),
+                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=True,
                 with_cp=False,
                 dcn=None):
        super(ResLayer, self).__init__()
        self.norm_eval = norm_eval
-        self.normalize = normalize
+        self.norm_cfg = norm_cfg
        self.stage = stage
        block, stage_blocks = ResNet.arch_settings[depth]
        stage_block = stage_blocks[stage]
@@ -39,7 +39,7 @@ class ResLayer(nn.Module):
            dilation=dilation,
            style=style,
            with_cp=with_cp,
-            normalize=self.normalize,
+            norm_cfg=self.norm_cfg,
            dcn=dcn)
        self.add_module('layer{}'.format(stage + 1), res_layer)
 
diff --git a/mmdet/models/utils/conv_module.py b/mmdet/models/utils/conv_module.py
index b3bf9c77ca0b72223d6ac817d8e4122be6a811fa..329b2352767949eab9c1e3d5adda52e1f23558e5 100644
--- a/mmdet/models/utils/conv_module.py
+++ b/mmdet/models/utils/conv_module.py
@@ -42,6 +42,27 @@ def build_conv_layer(cfg, *args, **kwargs):
 
 
 class ConvModule(nn.Module):
+    """Conv-Norm-Activation block.
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int or tuple[int]): Same as nn.Conv2d.
+        padding (int or tuple[int]): Same as nn.Conv2d.
+        dilation (int or tuple[int]): Same as nn.Conv2d.
+        groups (int): Same as nn.Conv2d.
+        bias (bool or str): If specified as `auto`, it will be decided by the
+            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+            False.
+        conv_cfg (dict): Config dict for convolution layer.
+        norm_cfg (dict): Config dict for normalization layer.
+        activation (str or None): Activation type, "ReLU" by default.
+        inplace (bool): Whether to use inplace mode for activation.
+        activate_last (bool): Whether to apply the activation layer in the
+            last. (Do not use this flag since the behavior and api may be
+            changed in the future.)
+    """
 
    def __init__(self,
                 in_channels,
@@ -51,35 +72,42 @@ class ConvModule(nn.Module):
                 padding=0,
                 dilation=1,
                 groups=1,
-                 bias=True,
+                 bias='auto',
                 conv_cfg=None,
-                 normalize=None,
+                 norm_cfg=None,
                 activation='relu',
                 inplace=True,
                 activate_last=True):
        super(ConvModule, self).__init__()
        assert conv_cfg is None or isinstance(conv_cfg, dict)
-        assert normalize is None or isinstance(normalize, dict)
-        self.with_norm = normalize is not None
-        self.with_activatation = activation is not None
-        self.with_bias = bias
-        self.activation = activation
+        assert norm_cfg is None or isinstance(norm_cfg, dict)
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.activation = activation
+        self.inplace = inplace
        self.activate_last = activate_last
+        self.with_norm = norm_cfg is not None
+        self.with_activatation = activation is not None
+        # if the conv layer is before a norm layer, bias is unnecessary.
+        if bias == 'auto':
+            bias = False if self.with_norm else True
+        self.with_bias = bias
+
        if self.with_norm and self.with_bias:
            warnings.warn('ConvModule has norm and bias at the same time')
 
-        self.conv = build_conv_layer(
-            conv_cfg,
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride,
-            padding,
-            dilation,
-            groups,
-            bias=bias)
-
+        # build convolution layer
+        self.conv = build_conv_layer(conv_cfg,
+                                     in_channels,
+                                     out_channels,
+                                     kernel_size,
+                                     stride=stride,
+                                     padding=padding,
+                                     dilation=dilation,
+                                     groups=groups,
+                                     bias=bias)
+        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
@@ -90,17 +118,21 @@ class ConvModule(nn.Module):
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups
 
+        # build normalization layers
        if self.with_norm:
            norm_channels = out_channels if self.activate_last else in_channels
-            self.norm_name, norm = build_norm_layer(normalize, norm_channels)
+            self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
            self.add_module(self.norm_name, norm)
 
+        # build activation layer
        if self.with_activatation:
-            assert activation in ['relu'], 'Only ReLU supported.'
+            if self.activation not in ['relu']:
+                raise ValueError('{} is currently not supported.'.format(
+                    self.activation))
            if self.activation == 'relu':
                self.activate = nn.ReLU(inplace=inplace)
 
-        # Default using msra init
+        # Use msra init by default
        self.init_weights()
 
    @property
@@ -121,6 +153,7 @@ class ConvModule(nn.Module):
            if activate and self.with_activatation:
                x = self.activate(x)
        else:
+            # WARN: this may be removed or modified
            if norm and self.with_norm:
                x = self.norm(x)
            if activate and self.with_activatation:
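Putting the `ConvModule` changes together, here is a minimal usage sketch; the tensor shapes are illustrative, and the import path assumes `ConvModule` is re-exported from `mmdet.models.utils` as in this tree:

```python
import torch
from mmdet.models.utils import ConvModule

norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)

# bias='auto' is the new default: a norm layer follows, so the conv bias
# resolves to False.
conv_gn = ConvModule(256, 256, 3, padding=1, norm_cfg=norm_cfg)
assert conv_gn.with_bias is False

# Without a norm_cfg, 'auto' resolves to True and the conv keeps its bias.
conv_plain = ConvModule(256, 256, 3, padding=1)
assert conv_plain.with_bias is True

x = torch.randn(2, 256, 32, 32)
out = conv_gn(x)  # conv -> GN -> ReLU (activate_last=True by default)
assert out.shape == x.shape
```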