diff --git a/mmdet/ops/dcn/modules/deform_pool.py b/mmdet/ops/dcn/modules/deform_pool.py index b7f6a2843bad14dfa4c70f7121e640fab7ffe67c..5e0196753ee1b427263bc397e0ae842af6a9938b 100644 --- a/mmdet/ops/dcn/modules/deform_pool.py +++ b/mmdet/ops/dcn/modules/deform_pool.py @@ -44,22 +44,28 @@ class DeformRoIPoolingPack(DeformRoIPooling): part_size=None, sample_per_part=4, trans_std=.0, + num_offset_fcs=3, deform_fc_channels=1024): super(DeformRoIPoolingPack, self).__init__(spatial_scale, out_size, out_channels, no_trans, group_size, part_size, sample_per_part, trans_std) + self.num_offset_fcs = num_offset_fcs self.deform_fc_channels = deform_fc_channels if not no_trans: - self.offset_fc = nn.Sequential( - nn.Linear(self.out_size * self.out_size * self.out_channels, - self.deform_fc_channels), - nn.ReLU(inplace=True), - nn.Linear(self.deform_fc_channels, self.deform_fc_channels), - nn.ReLU(inplace=True), - nn.Linear(self.deform_fc_channels, - self.out_size * self.out_size * 2)) + seq = [] + ic = self.out_size * self.out_size * self.out_channels + for i in range(self.num_offset_fcs): + if i < self.num_offset_fcs - 1: + oc = self.deform_fc_channels + else: + oc = self.out_size * self.out_size * 2 + seq.append(nn.Linear(ic, oc)) + ic = oc + if i < self.num_offset_fcs - 1: + seq.append(nn.ReLU(inplace=True)) + self.offset_fc = nn.Sequential(*seq) self.offset_fc[-1].weight.data.zero_() self.offset_fc[-1].bias.data.zero_() @@ -97,33 +103,49 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling): part_size=None, sample_per_part=4, trans_std=.0, + num_offset_fcs=3, + num_mask_fcs=2, deform_fc_channels=1024): super(ModulatedDeformRoIPoolingPack, self).__init__( spatial_scale, out_size, out_channels, no_trans, group_size, part_size, sample_per_part, trans_std) + self.num_offset_fcs = num_offset_fcs + self.num_mask_fcs = num_mask_fcs self.deform_fc_channels = deform_fc_channels if not no_trans: - self.offset_fc = nn.Sequential( - nn.Linear(self.out_size * self.out_size * self.out_channels, - self.deform_fc_channels), - nn.ReLU(inplace=True), - nn.Linear(self.deform_fc_channels, self.deform_fc_channels), - nn.ReLU(inplace=True), - nn.Linear(self.deform_fc_channels, - self.out_size * self.out_size * 2)) + offset_fc_seq = [] + ic = self.out_size * self.out_size * self.out_channels + for i in range(self.num_offset_fcs): + if i < self.num_offset_fcs - 1: + oc = self.deform_fc_channels + else: + oc = self.out_size * self.out_size * 2 + offset_fc_seq.append(nn.Linear(ic, oc)) + ic = oc + if i < self.num_offset_fcs - 1: + offset_fc_seq.append(nn.ReLU(inplace=True)) + self.offset_fc = nn.Sequential(*offset_fc_seq) self.offset_fc[-1].weight.data.zero_() self.offset_fc[-1].bias.data.zero_() - self.mask_fc = nn.Sequential( - nn.Linear(self.out_size * self.out_size * self.out_channels, - self.deform_fc_channels), - nn.ReLU(inplace=True), - nn.Linear(self.deform_fc_channels, - self.out_size * self.out_size * 1), - nn.Sigmoid()) - self.mask_fc[2].weight.data.zero_() - self.mask_fc[2].bias.data.zero_() + + mask_fc_seq = [] + ic = self.out_size * self.out_size * self.out_channels + for i in range(self.num_mask_fcs): + if i < self.num_mask_fcs - 1: + oc = self.deform_fc_channels + else: + oc = self.out_size * self.out_size + mask_fc_seq.append(nn.Linear(ic, oc)) + ic = oc + if i < self.num_mask_fcs - 1: + mask_fc_seq.append(nn.ReLU(inplace=True)) + else: + mask_fc_seq.append(nn.Sigmoid()) + self.mask_fc = nn.Sequential(*mask_fc_seq) + self.mask_fc[-2].weight.data.zero_() + self.mask_fc[-2].bias.data.zero_() def forward(self, data, rois): assert data.size(1) == self.out_channels