diff --git a/mmdet/models/bbox_heads/convfc_bbox_head.py b/mmdet/models/bbox_heads/convfc_bbox_head.py index 1485c7fd8ebab6cad9ad08d5333acc86a9293a23..4de176fb109a0dafe054d999a048f783b68f7cd4 100644 --- a/mmdet/models/bbox_heads/convfc_bbox_head.py +++ b/mmdet/models/bbox_heads/convfc_bbox_head.py @@ -22,6 +22,7 @@ class ConvFCBBoxHead(BBoxHead): num_reg_fcs=0, conv_out_channels=256, fc_out_channels=1024, + normalize=None, *args, **kwargs): super(ConvFCBBoxHead, self).__init__(*args, **kwargs) @@ -41,6 +42,8 @@ class ConvFCBBoxHead(BBoxHead): self.num_reg_fcs = num_reg_fcs self.conv_out_channels = conv_out_channels self.fc_out_channels = fc_out_channels + self.normalize = normalize + self.with_bias = normalize is None # add shared convs and fcs self.shared_convs, self.shared_fcs, last_layer_dim = \ diff --git a/mmdet/models/utils/conv_module.py b/mmdet/models/utils/conv_module.py index 25121972da29d8e4e83fb2301b8f8d25a1727f7e..f8e8cce9a3e4df3193bddc25272214c9720f9026 100644 --- a/mmdet/models/utils/conv_module.py +++ b/mmdet/models/utils/conv_module.py @@ -1,6 +1,7 @@ import warnings import torch.nn as nn +from mmcv.cnn import kaiming_init, constant_init from .norm import build_norm_layer @@ -51,15 +52,8 @@ class ConvModule(nn.Module): self.groups = self.conv.groups if self.with_norm: - # self.norm_type, self.norm_params = parse_norm(normalize) - # assert self.norm_type in [None, 'BN', 'SyncBN', 'GN', 'SN'] - # self.Norm2d = norm_cfg[self.norm_type] - if self.activate_last: - self.norm = build_norm_layer(normalize, out_channels) - # self.norm = self.Norm2d(out_channels, **self.norm_params) - else: - self.norm = build_norm_layer(normalize, in_channels) - # self.norm = self.Norm2d(in_channels, **self.norm_params) + norm_channels = out_channels if self.activate_last else in_channels + self.norm = build_norm_layer(normalize, norm_channels) if self.with_activatation: assert activation in ['relu'], 'Only ReLU supported.' @@ -71,13 +65,9 @@ class ConvModule(nn.Module): def init_weights(self): nonlinearity = 'relu' if self.activation is None else self.activation - nn.init.kaiming_normal_( - self.conv.weight, mode='fan_out', nonlinearity=nonlinearity) - if self.with_bias: - nn.init.constant_(self.conv.bias, 0) + kaiming_init(self.conv, nonlinearity=nonlinearity) if self.with_norm: - nn.init.constant_(self.norm.weight, 1) - nn.init.constant_(self.norm.bias, 0) + constant_init(self.norm, 1, bias=0) def forward(self, x, activate=True, norm=True): if self.activate_last: