Skip to content
Snippets Groups Projects
Commit 5fcec9ae authored by Kai Chen's avatar Kai Chen
Browse files

allow specifying num_offset_fcs and num_mask_fcs

parent 527629fe
No related branches found
No related tags found
No related merge requests found
...@@ -44,22 +44,28 @@ class DeformRoIPoolingPack(DeformRoIPooling): ...@@ -44,22 +44,28 @@ class DeformRoIPoolingPack(DeformRoIPooling):
part_size=None, part_size=None,
sample_per_part=4, sample_per_part=4,
trans_std=.0, trans_std=.0,
num_offset_fcs=3,
deform_fc_channels=1024): deform_fc_channels=1024):
super(DeformRoIPoolingPack, super(DeformRoIPoolingPack,
self).__init__(spatial_scale, out_size, out_channels, no_trans, self).__init__(spatial_scale, out_size, out_channels, no_trans,
group_size, part_size, sample_per_part, trans_std) group_size, part_size, sample_per_part, trans_std)
self.num_offset_fcs = num_offset_fcs
self.deform_fc_channels = deform_fc_channels self.deform_fc_channels = deform_fc_channels
if not no_trans: if not no_trans:
self.offset_fc = nn.Sequential( seq = []
nn.Linear(self.out_size * self.out_size * self.out_channels, ic = self.out_size * self.out_size * self.out_channels
self.deform_fc_channels), for i in range(self.num_offset_fcs):
nn.ReLU(inplace=True), if i < self.num_offset_fcs - 1:
nn.Linear(self.deform_fc_channels, self.deform_fc_channels), oc = self.deform_fc_channels
nn.ReLU(inplace=True), else:
nn.Linear(self.deform_fc_channels, oc = self.out_size * self.out_size * 2
self.out_size * self.out_size * 2)) seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_offset_fcs - 1:
seq.append(nn.ReLU(inplace=True))
self.offset_fc = nn.Sequential(*seq)
self.offset_fc[-1].weight.data.zero_() self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_() self.offset_fc[-1].bias.data.zero_()
...@@ -97,33 +103,49 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling): ...@@ -97,33 +103,49 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
part_size=None, part_size=None,
sample_per_part=4, sample_per_part=4,
trans_std=.0, trans_std=.0,
num_offset_fcs=3,
num_mask_fcs=2,
deform_fc_channels=1024): deform_fc_channels=1024):
super(ModulatedDeformRoIPoolingPack, self).__init__( super(ModulatedDeformRoIPoolingPack, self).__init__(
spatial_scale, out_size, out_channels, no_trans, group_size, spatial_scale, out_size, out_channels, no_trans, group_size,
part_size, sample_per_part, trans_std) part_size, sample_per_part, trans_std)
self.num_offset_fcs = num_offset_fcs
self.num_mask_fcs = num_mask_fcs
self.deform_fc_channels = deform_fc_channels self.deform_fc_channels = deform_fc_channels
if not no_trans: if not no_trans:
self.offset_fc = nn.Sequential( offset_fc_seq = []
nn.Linear(self.out_size * self.out_size * self.out_channels, ic = self.out_size * self.out_size * self.out_channels
self.deform_fc_channels), for i in range(self.num_offset_fcs):
nn.ReLU(inplace=True), if i < self.num_offset_fcs - 1:
nn.Linear(self.deform_fc_channels, self.deform_fc_channels), oc = self.deform_fc_channels
nn.ReLU(inplace=True), else:
nn.Linear(self.deform_fc_channels, oc = self.out_size * self.out_size * 2
self.out_size * self.out_size * 2)) offset_fc_seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_offset_fcs - 1:
offset_fc_seq.append(nn.ReLU(inplace=True))
self.offset_fc = nn.Sequential(*offset_fc_seq)
self.offset_fc[-1].weight.data.zero_() self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_() self.offset_fc[-1].bias.data.zero_()
self.mask_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.out_channels, mask_fc_seq = []
self.deform_fc_channels), ic = self.out_size * self.out_size * self.out_channels
nn.ReLU(inplace=True), for i in range(self.num_mask_fcs):
nn.Linear(self.deform_fc_channels, if i < self.num_mask_fcs - 1:
self.out_size * self.out_size * 1), oc = self.deform_fc_channels
nn.Sigmoid()) else:
self.mask_fc[2].weight.data.zero_() oc = self.out_size * self.out_size
self.mask_fc[2].bias.data.zero_() mask_fc_seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_mask_fcs - 1:
mask_fc_seq.append(nn.ReLU(inplace=True))
else:
mask_fc_seq.append(nn.Sigmoid())
self.mask_fc = nn.Sequential(*mask_fc_seq)
self.mask_fc[-2].weight.data.zero_()
self.mask_fc[-2].bias.data.zero_()
def forward(self, data, rois): def forward(self, data, rois):
assert data.size(1) == self.out_channels assert data.size(1) == self.out_channels
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment