diff --git a/mmdet/models/utils/conv_module.py b/mmdet/models/utils/conv_module.py
index 623d691259c8de23dda86d5795e82002e0b49ccd..ef9a538f85932402c5faa1cbf5a0176b426ee995 100644
--- a/mmdet/models/utils/conv_module.py
+++ b/mmdet/models/utils/conv_module.py
@@ -42,7 +42,7 @@ def build_conv_layer(cfg, *args, **kwargs):
 
 
 class ConvModule(nn.Module):
-    """Conv-Norm-Activation block.
+    """A conv block that contains conv/norm/activation layers.
 
     Args:
         in_channels (int): Same as nn.Conv2d.
@@ -59,9 +59,9 @@ class ConvModule(nn.Module):
         norm_cfg (dict): Config dict for normalization layer.
         activation (str or None): Activation type, "ReLU" by default.
         inplace (bool): Whether to use inplace mode for activation.
-        activate_last (bool): Whether to apply the activation layer in the
-            last. (Do not use this flag since the behavior and api may be
-            changed in the future.)
+        order (tuple[str]): The order of conv/norm/activation layers. It is a
+            sequence of "conv", "norm" and "act". Examples are
+            ("conv", "norm", "act") and ("act", "conv", "norm").
     """
 
     def __init__(self,
@@ -77,7 +77,7 @@ class ConvModule(nn.Module):
                  norm_cfg=None,
                  activation='relu',
                  inplace=True,
-                 activate_last=True):
+                 order=('conv', 'norm', 'act')):
         super(ConvModule, self).__init__()
         assert conv_cfg is None or isinstance(conv_cfg, dict)
         assert norm_cfg is None or isinstance(norm_cfg, dict)
@@ -85,7 +85,9 @@ class ConvModule(nn.Module):
         self.norm_cfg = norm_cfg
         self.activation = activation
         self.inplace = inplace
-        self.activate_last = activate_last
+        self.order = order
+        assert isinstance(self.order, tuple) and len(self.order) == 3
+        assert set(order) == set(['conv', 'norm', 'act'])
 
         self.with_norm = norm_cfg is not None
         self.with_activatation = activation is not None
@@ -121,12 +123,17 @@ class ConvModule(nn.Module):
 
         # build normalization layers
         if self.with_norm:
-            norm_channels = out_channels if self.activate_last else in_channels
+            # norm layer is after conv layer
+            if order.index('norm') > order.index('conv'):
+                norm_channels = out_channels
+            else:
+                norm_channels = in_channels
             self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
             self.add_module(self.norm_name, norm)
 
         # build activation layer
         if self.with_activatation:
+            # TODO: introduce `act_cfg` and supports more activation layers
             if self.activation not in ['relu']:
                 raise ValueError('{} is currently not supported.'.format(
                     self.activation))
@@ -147,17 +154,11 @@ class ConvModule(nn.Module):
             constant_init(self.norm, 1, bias=0)
 
     def forward(self, x, activate=True, norm=True):
-        if self.activate_last:
-            x = self.conv(x)
-            if norm and self.with_norm:
+        for layer in self.order:
+            if layer == 'conv':
+                x = self.conv(x)
+            elif layer == 'norm' and norm and self.with_norm:
                 x = self.norm(x)
-            if activate and self.with_activatation:
+            elif layer == 'act' and activate and self.with_activatation:
                 x = self.activate(x)
-        else:
-            # WARN: this may be removed or modified
-            if norm and self.with_norm:
-                x = self.norm(x)
-            if activate and self.with_activatation:
-                x = self.activate(x)
-            x = self.conv(x)
         return x
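
For reference, a minimal usage sketch of the new `order` argument (not part of the patch). The import path, the `dict(type='BN')` norm config, and the expected shapes are assumptions based on mmdet conventions, not something this diff itself guarantees:

    import torch
    from mmdet.models.utils import ConvModule  # assumed import path

    x = torch.randn(1, 3, 32, 32)

    # Default order ('conv', 'norm', 'act') reproduces the old
    # activate_last=True behavior; since 'norm' follows 'conv', the norm
    # layer is built with out_channels (16).
    block = ConvModule(3, 16, 3, padding=1, norm_cfg=dict(type='BN'))
    print(block(x).shape)  # expected: torch.Size([1, 16, 32, 32])

    # Norm-first ordering: 'norm' now precedes 'conv', so the norm layer
    # is built with in_channels (3) instead, per the new norm_channels logic.
    norm_first = ConvModule(
        3, 16, 3, padding=1, norm_cfg=dict(type='BN'),
        order=('norm', 'conv', 'act'))
    print(norm_first(x).shape)  # expected: torch.Size([1, 16, 32, 32])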