Unverified commit 810b7110, authored by Kai Chen, committed by GitHub

Merge pull request #54 from hellock/hotfix

Bug fix for ConvFCBBoxHead arguments
Parents: c8cc01e8, 8e098356
@@ -22,6 +22,7 @@ class ConvFCBBoxHead(BBoxHead):
                  num_reg_fcs=0,
                  conv_out_channels=256,
                  fc_out_channels=1024,
+                 normalize=None,
                  *args,
                  **kwargs):
         super(ConvFCBBoxHead, self).__init__(*args, **kwargs)
@@ -41,6 +42,8 @@ class ConvFCBBoxHead(BBoxHead):
         self.num_reg_fcs = num_reg_fcs
         self.conv_out_channels = conv_out_channels
         self.fc_out_channels = fc_out_channels
+        self.normalize = normalize
+        self.with_bias = normalize is None

         # add shared convs and fcs
         self.shared_convs, self.shared_fcs, last_layer_dim = \
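The head change threads a `normalize` config through to its conv layers and records `with_bias = normalize is None`: a convolution keeps its bias only when no norm layer follows it, since a norm layer's own affine shift makes the conv bias redundant. Below is a minimal, self-contained sketch of that pattern in plain PyTorch, not the actual mmdetection code; the `build_conv` helper and the handling of a `dict(type='BN')` config are illustrative assumptions standing in for `build_norm_layer`.

# Sketch only: mirrors the `with_bias = normalize is None` rule from the diff.
import torch
import torch.nn as nn


def build_conv(in_channels, out_channels, normalize=None):
    """Build a conv (plus optional norm) whose bias is disabled when a norm follows."""
    with_bias = normalize is None  # mirrors the new flag in the diff
    layers = [nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=with_bias)]
    if normalize is not None:
        # Stand-in for build_norm_layer(normalize, out_channels); only a plain
        # BatchNorm config is handled here, for illustration.
        assert normalize.get('type') == 'BN'
        layers.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*layers)


x = torch.randn(2, 256, 7, 7)
plain = build_conv(256, 256)                               # bias=True, no norm
normed = build_conv(256, 256, normalize=dict(type='BN'))   # bias=False + BN
print(plain(x).shape, normed(x).shape)

The second file in the commit, which defines ConvModule, is simplified as follows: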
 import warnings

 import torch.nn as nn
+from mmcv.cnn import kaiming_init, constant_init

 from .norm import build_norm_layer
@@ -51,15 +52,8 @@ class ConvModule(nn.Module):
             self.groups = self.conv.groups

         if self.with_norm:
-            # self.norm_type, self.norm_params = parse_norm(normalize)
-            # assert self.norm_type in [None, 'BN', 'SyncBN', 'GN', 'SN']
-            # self.Norm2d = norm_cfg[self.norm_type]
-            if self.activate_last:
-                self.norm = build_norm_layer(normalize, out_channels)
-                # self.norm = self.Norm2d(out_channels, **self.norm_params)
-            else:
-                self.norm = build_norm_layer(normalize, in_channels)
-                # self.norm = self.Norm2d(in_channels, **self.norm_params)
+            norm_channels = out_channels if self.activate_last else in_channels
+            self.norm = build_norm_layer(normalize, norm_channels)

         if self.with_activatation:
             assert activation in ['relu'], 'Only ReLU supported.'
@@ -71,13 +65,9 @@ class ConvModule(nn.Module):
     def init_weights(self):
         nonlinearity = 'relu' if self.activation is None else self.activation
-        nn.init.kaiming_normal_(
-            self.conv.weight, mode='fan_out', nonlinearity=nonlinearity)
-        if self.with_bias:
-            nn.init.constant_(self.conv.bias, 0)
+        kaiming_init(self.conv, nonlinearity=nonlinearity)
         if self.with_norm:
-            nn.init.constant_(self.norm.weight, 1)
-            nn.init.constant_(self.norm.bias, 0)
+            constant_init(self.norm, 1, bias=0)

     def forward(self, x, activate=True, norm=True):
         if self.activate_last:
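The ConvModule change replaces the hand-written nn.init calls with mmcv's kaiming_init and constant_init helpers, each of which initializes a module's weight and bias in one call. The sketch below is a rough, assumed equivalent of those helpers for orientation, not a copy of mmcv's source.

# Assumed behavior of the mmcv helpers used above -- a sketch, not mmcv itself.
import torch.nn as nn


def kaiming_init(module, mode='fan_out', nonlinearity='relu', bias=0):
    # Kaiming-normal init of the weight plus a constant bias, folding the
    # separate nn.init.kaiming_normal_ / nn.init.constant_ calls removed above.
    nn.init.kaiming_normal_(module.weight, mode=mode, nonlinearity=nonlinearity)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def constant_init(module, val, bias=0):
    # Constant init for norm layers: weight (scale) to `val`, bias (shift) to `bias`.
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


conv = nn.Conv2d(16, 32, 3, bias=False)
bn = nn.BatchNorm2d(32)
kaiming_init(conv, nonlinearity='relu')
constant_init(bn, 1, bias=0)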