diff --git a/mmdet/ops/dcn/__init__.py b/mmdet/ops/dcn/__init__.py index 1e158d01b04fbc01278021de19ed641bc6fde414..165e63725354de429a448d866f665cccca991916 100644 --- a/mmdet/ops/dcn/__init__.py +++ b/mmdet/ops/dcn/__init__.py @@ -1,13 +1,13 @@ from .functions.deform_conv import deform_conv, modulated_deform_conv from .functions.deform_pool import deform_roi_pooling from .modules.deform_conv import (DeformConv, ModulatedDeformConv, - ModulatedDeformConvPack) + DeformConvPack, ModulatedDeformConvPack) from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack, ModulatedDeformRoIPoolingPack) __all__ = [ - 'DeformConv', 'DeformRoIPooling', 'DeformRoIPoolingPack', - 'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv', - 'ModulatedDeformConvPack', 'deform_conv', - 'modulated_deform_conv', 'deform_roi_pooling' + 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', + 'ModulatedDeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack', + 'ModulatedDeformRoIPoolingPack', 'deform_conv', 'modulated_deform_conv', + 'deform_roi_pooling' ] diff --git a/mmdet/ops/dcn/modules/deform_conv.py b/mmdet/ops/dcn/modules/deform_conv.py index e1b0a63c7a0b65ecf91362c5c2469d94635cc82c..016ad16647257f9497b2bdc6fc192eb7ada839e5 100644 --- a/mmdet/ops/dcn/modules/deform_conv.py +++ b/mmdet/ops/dcn/modules/deform_conv.py @@ -19,15 +19,16 @@ class DeformConv(nn.Module): groups=1, deformable_groups=1, bias=False): - assert not bias super(DeformConv, self).__init__() + assert not bias assert in_channels % groups == 0, \ 'in_channels {} cannot be divisible by groups {}'.format( in_channels, groups) assert out_channels % groups == 0, \ 'out_channels {} cannot be divisible by groups {}'.format( out_channels, groups) + self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = _pair(kernel_size) @@ -50,10 +51,34 @@ class DeformConv(nn.Module): stdv = 1. / math.sqrt(n) self.weight.data.uniform_(-stdv, stdv) - def forward(self, input, offset): - return deform_conv(input, offset, self.weight, self.stride, - self.padding, self.dilation, self.groups, - self.deformable_groups) + def forward(self, x, offset): + return deform_conv(x, offset, self.weight, self.stride, self.padding, + self.dilation, self.groups, self.deformable_groups) + + +class DeformConvPack(DeformConv): + + def __init__(self, *args, **kwargs): + super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 2 * self.kernel_size[0] * + self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + offset = self.conv_offset(x) + return deform_conv(x, offset, self.weight, self.stride, self.padding, + self.dilation, self.groups, self.deformable_groups) class ModulatedDeformConv(nn.Module): @@ -97,30 +122,19 @@ class ModulatedDeformConv(nn.Module): if self.bias is not None: self.bias.data.zero_() - def forward(self, input, offset, mask): - return modulated_deform_conv( - input, offset, mask, self.weight, self.bias, self.stride, - self.padding, self.dilation, self.groups, self.deformable_groups) + def forward(self, x, offset, mask): + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, + self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) class ModulatedDeformConvPack(ModulatedDeformConv): - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - deformable_groups=1, - bias=True): - super(ModulatedDeformConvPack, self).__init__( - in_channels, out_channels, kernel_size, stride, padding, dilation, - groups, deformable_groups, bias) + def __init__(self, *args, **kwargs): + super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) self.conv_offset_mask = nn.Conv2d( - self.in_channels // self.groups, + self.in_channels, self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], kernel_size=self.kernel_size, @@ -133,11 +147,11 @@ class ModulatedDeformConvPack(ModulatedDeformConv): self.conv_offset_mask.weight.data.zero_() self.conv_offset_mask.bias.data.zero_() - def forward(self, input): - out = self.conv_offset_mask(input) + def forward(self, x): + out = self.conv_offset_mask(x) o1, o2, mask = torch.chunk(out, 3, dim=1) offset = torch.cat((o1, o2), dim=1) mask = torch.sigmoid(mask) - return modulated_deform_conv( - input, offset, mask, self.weight, self.bias, self.stride, - self.padding, self.dilation, self.groups, self.deformable_groups) + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, + self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups)