diff --git a/mmdet/ops/__init__.py b/mmdet/ops/__init__.py index 21c2f327d019f05fd5a32575bb0c4ad71fe39ef8..9576335a3063a1e05ccc69b78fe28c5a85c16acb 100644 --- a/mmdet/ops/__init__.py +++ b/mmdet/ops/__init__.py @@ -1,14 +1,14 @@ -from .dcn import (DeformConv, DeformRoIPooling, DeformRoIPoolingPack, - ModulatedDeformRoIPoolingPack, ModulatedDeformConv, - ModulatedDeformConvPack, deform_conv, modulated_deform_conv, - deform_roi_pooling) +from .dcn import (DeformConv, DeformConvPack, ModulatedDeformConv, + ModulatedDeformConvPack, DeformRoIPooling, + DeformRoIPoolingPack, ModulatedDeformRoIPoolingPack, + deform_conv, modulated_deform_conv, deform_roi_pooling) from .nms import nms, soft_nms from .roi_align import RoIAlign, roi_align from .roi_pool import RoIPool, roi_pool __all__ = [ 'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', - 'DeformConv', 'DeformRoIPooling', 'DeformRoIPoolingPack', + 'DeformConv', 'DeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack', 'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv', 'deform_roi_pooling' diff --git a/mmdet/ops/dcn/__init__.py b/mmdet/ops/dcn/__init__.py index 1e158d01b04fbc01278021de19ed641bc6fde414..165e63725354de429a448d866f665cccca991916 100644 --- a/mmdet/ops/dcn/__init__.py +++ b/mmdet/ops/dcn/__init__.py @@ -1,13 +1,13 @@ from .functions.deform_conv import deform_conv, modulated_deform_conv from .functions.deform_pool import deform_roi_pooling from .modules.deform_conv import (DeformConv, ModulatedDeformConv, - ModulatedDeformConvPack) + DeformConvPack, ModulatedDeformConvPack) from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack, ModulatedDeformRoIPoolingPack) __all__ = [ - 'DeformConv', 'DeformRoIPooling', 'DeformRoIPoolingPack', - 'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv', - 'ModulatedDeformConvPack', 'deform_conv', - 'modulated_deform_conv', 'deform_roi_pooling' + 'DeformConv', 
'DeformConvPack', 'ModulatedDeformConv', + 'ModulatedDeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack', + 'ModulatedDeformRoIPoolingPack', 'deform_conv', 'modulated_deform_conv', + 'deform_roi_pooling' ] diff --git a/mmdet/ops/dcn/modules/deform_conv.py b/mmdet/ops/dcn/modules/deform_conv.py index e1b0a63c7a0b65ecf91362c5c2469d94635cc82c..016ad16647257f9497b2bdc6fc192eb7ada839e5 100644 --- a/mmdet/ops/dcn/modules/deform_conv.py +++ b/mmdet/ops/dcn/modules/deform_conv.py @@ -19,15 +19,16 @@ class DeformConv(nn.Module): groups=1, deformable_groups=1, bias=False): - assert not bias super(DeformConv, self).__init__() + assert not bias assert in_channels % groups == 0, \ 'in_channels {} cannot be divisible by groups {}'.format( in_channels, groups) assert out_channels % groups == 0, \ 'out_channels {} cannot be divisible by groups {}'.format( out_channels, groups) + self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = _pair(kernel_size) @@ -50,10 +51,34 @@ class DeformConv(nn.Module): stdv = 1. 
/ math.sqrt(n) self.weight.data.uniform_(-stdv, stdv) - def forward(self, input, offset): - return deform_conv(input, offset, self.weight, self.stride, - self.padding, self.dilation, self.groups, - self.deformable_groups) + def forward(self, x, offset): + return deform_conv(x, offset, self.weight, self.stride, self.padding, + self.dilation, self.groups, self.deformable_groups) + + +class DeformConvPack(DeformConv): + + def __init__(self, *args, **kwargs): + super(DeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 2 * self.kernel_size[0] * + self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + offset = self.conv_offset(x) + return deform_conv(x, offset, self.weight, self.stride, self.padding, + self.dilation, self.groups, self.deformable_groups) class ModulatedDeformConv(nn.Module): @@ -97,30 +122,19 @@ class ModulatedDeformConv(nn.Module): if self.bias is not None: self.bias.data.zero_() - def forward(self, input, offset, mask): - return modulated_deform_conv( - input, offset, mask, self.weight, self.bias, self.stride, - self.padding, self.dilation, self.groups, self.deformable_groups) + def forward(self, x, offset, mask): + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, + self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) class ModulatedDeformConvPack(ModulatedDeformConv): - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - deformable_groups=1, - bias=True): - super(ModulatedDeformConvPack, self).__init__( - in_channels, out_channels, kernel_size, stride, padding, dilation, - groups, deformable_groups, bias) + def __init__(self, *args, 
**kwargs): + super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) self.conv_offset_mask = nn.Conv2d( - self.in_channels // self.groups, + self.in_channels, self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], kernel_size=self.kernel_size, @@ -133,11 +147,11 @@ class ModulatedDeformConvPack(ModulatedDeformConv): self.conv_offset_mask.weight.data.zero_() self.conv_offset_mask.bias.data.zero_() - def forward(self, input): - out = self.conv_offset_mask(input) + def forward(self, x): + out = self.conv_offset_mask(x) o1, o2, mask = torch.chunk(out, 3, dim=1) offset = torch.cat((o1, o2), dim=1) mask = torch.sigmoid(mask) - return modulated_deform_conv( - input, offset, mask, self.weight, self.bias, self.stride, - self.padding, self.dilation, self.groups, self.deformable_groups) + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, + self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) diff --git a/mmdet/ops/dcn/modules/deform_pool.py b/mmdet/ops/dcn/modules/deform_pool.py index b7f6a2843bad14dfa4c70f7121e640fab7ffe67c..5e0196753ee1b427263bc397e0ae842af6a9938b 100644 --- a/mmdet/ops/dcn/modules/deform_pool.py +++ b/mmdet/ops/dcn/modules/deform_pool.py @@ -44,22 +44,28 @@ class DeformRoIPoolingPack(DeformRoIPooling): part_size=None, sample_per_part=4, trans_std=.0, + num_offset_fcs=3, deform_fc_channels=1024): super(DeformRoIPoolingPack, self).__init__(spatial_scale, out_size, out_channels, no_trans, group_size, part_size, sample_per_part, trans_std) + self.num_offset_fcs = num_offset_fcs self.deform_fc_channels = deform_fc_channels if not no_trans: - self.offset_fc = nn.Sequential( - nn.Linear(self.out_size * self.out_size * self.out_channels, - self.deform_fc_channels), - nn.ReLU(inplace=True), - nn.Linear(self.deform_fc_channels, self.deform_fc_channels), - nn.ReLU(inplace=True), - nn.Linear(self.deform_fc_channels, - self.out_size * self.out_size * 2)) + seq = [] + ic = self.out_size * 
self.out_size * self.out_channels + for i in range(self.num_offset_fcs): + if i < self.num_offset_fcs - 1: + oc = self.deform_fc_channels + else: + oc = self.out_size * self.out_size * 2 + seq.append(nn.Linear(ic, oc)) + ic = oc + if i < self.num_offset_fcs - 1: + seq.append(nn.ReLU(inplace=True)) + self.offset_fc = nn.Sequential(*seq) self.offset_fc[-1].weight.data.zero_() self.offset_fc[-1].bias.data.zero_() @@ -97,33 +103,49 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling): part_size=None, sample_per_part=4, trans_std=.0, + num_offset_fcs=3, + num_mask_fcs=2, deform_fc_channels=1024): super(ModulatedDeformRoIPoolingPack, self).__init__( spatial_scale, out_size, out_channels, no_trans, group_size, part_size, sample_per_part, trans_std) + self.num_offset_fcs = num_offset_fcs + self.num_mask_fcs = num_mask_fcs self.deform_fc_channels = deform_fc_channels if not no_trans: - self.offset_fc = nn.Sequential( - nn.Linear(self.out_size * self.out_size * self.out_channels, - self.deform_fc_channels), - nn.ReLU(inplace=True), - nn.Linear(self.deform_fc_channels, self.deform_fc_channels), - nn.ReLU(inplace=True), - nn.Linear(self.deform_fc_channels, - self.out_size * self.out_size * 2)) + offset_fc_seq = [] + ic = self.out_size * self.out_size * self.out_channels + for i in range(self.num_offset_fcs): + if i < self.num_offset_fcs - 1: + oc = self.deform_fc_channels + else: + oc = self.out_size * self.out_size * 2 + offset_fc_seq.append(nn.Linear(ic, oc)) + ic = oc + if i < self.num_offset_fcs - 1: + offset_fc_seq.append(nn.ReLU(inplace=True)) + self.offset_fc = nn.Sequential(*offset_fc_seq) self.offset_fc[-1].weight.data.zero_() self.offset_fc[-1].bias.data.zero_() - self.mask_fc = nn.Sequential( - nn.Linear(self.out_size * self.out_size * self.out_channels, - self.deform_fc_channels), - nn.ReLU(inplace=True), - nn.Linear(self.deform_fc_channels, - self.out_size * self.out_size * 1), - nn.Sigmoid()) - self.mask_fc[2].weight.data.zero_() - 
self.mask_fc[2].bias.data.zero_() + + mask_fc_seq = [] + ic = self.out_size * self.out_size * self.out_channels + for i in range(self.num_mask_fcs): + if i < self.num_mask_fcs - 1: + oc = self.deform_fc_channels + else: + oc = self.out_size * self.out_size + mask_fc_seq.append(nn.Linear(ic, oc)) + ic = oc + if i < self.num_mask_fcs - 1: + mask_fc_seq.append(nn.ReLU(inplace=True)) + else: + mask_fc_seq.append(nn.Sigmoid()) + self.mask_fc = nn.Sequential(*mask_fc_seq) + self.mask_fc[-2].weight.data.zero_() + self.mask_fc[-2].bias.data.zero_() def forward(self, data, rois): assert data.size(1) == self.out_channels