Skip to content
Snippets Groups Projects
Unverified Commit 64b1c8b6 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #470 from hellock/dcn-api

Add some arguments to DCN ops
parents e421e832 c35a11b6
No related branches found
No related tags found
No related merge requests found
from .dcn import (DeformConv, DeformRoIPooling, DeformRoIPoolingPack,
ModulatedDeformRoIPoolingPack, ModulatedDeformConv,
ModulatedDeformConvPack, deform_conv, modulated_deform_conv,
deform_roi_pooling)
from .dcn import (DeformConv, DeformConvPack, ModulatedDeformConv,
ModulatedDeformConvPack, DeformRoIPooling,
DeformRoIPoolingPack, ModulatedDeformRoIPoolingPack,
deform_conv, modulated_deform_conv, deform_roi_pooling)
from .nms import nms, soft_nms
from .roi_align import RoIAlign, roi_align
from .roi_pool import RoIPool, roi_pool
__all__ = [
'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool',
'DeformConv', 'DeformRoIPooling', 'DeformRoIPoolingPack',
'DeformConv', 'DeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv',
'deform_roi_pooling'
......
from .functions.deform_conv import deform_conv, modulated_deform_conv
from .functions.deform_pool import deform_roi_pooling
from .modules.deform_conv import (DeformConv, ModulatedDeformConv,
ModulatedDeformConvPack)
DeformConvPack, ModulatedDeformConvPack)
from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack,
ModulatedDeformRoIPoolingPack)
__all__ = [
'DeformConv', 'DeformRoIPooling', 'DeformRoIPoolingPack',
'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
'ModulatedDeformConvPack', 'deform_conv',
'modulated_deform_conv', 'deform_roi_pooling'
'DeformConv', 'DeformConvPack', 'ModulatedDeformConv',
'ModulatedDeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
'ModulatedDeformRoIPoolingPack', 'deform_conv', 'modulated_deform_conv',
'deform_roi_pooling'
]
......@@ -19,15 +19,16 @@ class DeformConv(nn.Module):
groups=1,
deformable_groups=1,
bias=False):
assert not bias
super(DeformConv, self).__init__()
assert not bias
assert in_channels % groups == 0, \
'in_channels {} cannot be divisible by groups {}'.format(
in_channels, groups)
assert out_channels % groups == 0, \
'out_channels {} cannot be divisible by groups {}'.format(
out_channels, groups)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
......@@ -50,10 +51,34 @@ class DeformConv(nn.Module):
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, input, offset):
return deform_conv(input, offset, self.weight, self.stride,
self.padding, self.dilation, self.groups,
self.deformable_groups)
def forward(self, x, offset):
return deform_conv(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
class DeformConvPack(DeformConv):
def __init__(self, *args, **kwargs):
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deformable_groups * 2 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
offset = self.conv_offset(x)
return deform_conv(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
class ModulatedDeformConv(nn.Module):
......@@ -97,30 +122,19 @@ class ModulatedDeformConv(nn.Module):
if self.bias is not None:
self.bias.data.zero_()
def forward(self, input, offset, mask):
return modulated_deform_conv(
input, offset, mask, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups, self.deformable_groups)
def forward(self, x, offset, mask):
return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
class ModulatedDeformConvPack(ModulatedDeformConv):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True):
super(ModulatedDeformConvPack, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, deformable_groups, bias)
def __init__(self, *args, **kwargs):
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset_mask = nn.Conv2d(
self.in_channels // self.groups,
self.in_channels,
self.deformable_groups * 3 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
......@@ -133,11 +147,11 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
self.conv_offset_mask.weight.data.zero_()
self.conv_offset_mask.bias.data.zero_()
def forward(self, input):
out = self.conv_offset_mask(input)
def forward(self, x):
out = self.conv_offset_mask(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return modulated_deform_conv(
input, offset, mask, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups, self.deformable_groups)
return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
......@@ -44,22 +44,28 @@ class DeformRoIPoolingPack(DeformRoIPooling):
part_size=None,
sample_per_part=4,
trans_std=.0,
num_offset_fcs=3,
deform_fc_channels=1024):
super(DeformRoIPoolingPack,
self).__init__(spatial_scale, out_size, out_channels, no_trans,
group_size, part_size, sample_per_part, trans_std)
self.num_offset_fcs = num_offset_fcs
self.deform_fc_channels = deform_fc_channels
if not no_trans:
self.offset_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.out_channels,
self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels,
self.out_size * self.out_size * 2))
seq = []
ic = self.out_size * self.out_size * self.out_channels
for i in range(self.num_offset_fcs):
if i < self.num_offset_fcs - 1:
oc = self.deform_fc_channels
else:
oc = self.out_size * self.out_size * 2
seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_offset_fcs - 1:
seq.append(nn.ReLU(inplace=True))
self.offset_fc = nn.Sequential(*seq)
self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_()
......@@ -97,33 +103,49 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
part_size=None,
sample_per_part=4,
trans_std=.0,
num_offset_fcs=3,
num_mask_fcs=2,
deform_fc_channels=1024):
super(ModulatedDeformRoIPoolingPack, self).__init__(
spatial_scale, out_size, out_channels, no_trans, group_size,
part_size, sample_per_part, trans_std)
self.num_offset_fcs = num_offset_fcs
self.num_mask_fcs = num_mask_fcs
self.deform_fc_channels = deform_fc_channels
if not no_trans:
self.offset_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.out_channels,
self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels,
self.out_size * self.out_size * 2))
offset_fc_seq = []
ic = self.out_size * self.out_size * self.out_channels
for i in range(self.num_offset_fcs):
if i < self.num_offset_fcs - 1:
oc = self.deform_fc_channels
else:
oc = self.out_size * self.out_size * 2
offset_fc_seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_offset_fcs - 1:
offset_fc_seq.append(nn.ReLU(inplace=True))
self.offset_fc = nn.Sequential(*offset_fc_seq)
self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_()
self.mask_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.out_channels,
self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels,
self.out_size * self.out_size * 1),
nn.Sigmoid())
self.mask_fc[2].weight.data.zero_()
self.mask_fc[2].bias.data.zero_()
mask_fc_seq = []
ic = self.out_size * self.out_size * self.out_channels
for i in range(self.num_mask_fcs):
if i < self.num_mask_fcs - 1:
oc = self.deform_fc_channels
else:
oc = self.out_size * self.out_size
mask_fc_seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_mask_fcs - 1:
mask_fc_seq.append(nn.ReLU(inplace=True))
else:
mask_fc_seq.append(nn.Sigmoid())
self.mask_fc = nn.Sequential(*mask_fc_seq)
self.mask_fc[-2].weight.data.zero_()
self.mask_fc[-2].bias.data.zero_()
def forward(self, data, rois):
assert data.size(1) == self.out_channels
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment