diff --git a/mmdet/ops/dcn/deform_pool.py b/mmdet/ops/dcn/deform_pool.py index 2d09dec8c82fd01fc947e7258b5bbb54e3245d43..99a4a3618971e024c31cb3e3bf90035fe4759509 100644 --- a/mmdet/ops/dcn/deform_pool.py +++ b/mmdet/ops/dcn/deform_pool.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn from torch.autograd import Function from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair from . import deform_pool_cuda @@ -21,6 +22,12 @@ class DeformRoIPoolingFunction(Function): part_size=None, sample_per_part=4, trans_std=.0): + # TODO: support unsquare RoIs + out_h, out_w = _pair(out_size) + assert isinstance(out_h, int) and isinstance(out_w, int) + assert out_h == out_w + out_size = out_h # out_h and out_w must be equal + ctx.spatial_scale = spatial_scale ctx.out_size = out_size ctx.out_channels = out_channels @@ -85,7 +92,7 @@ class DeformRoIPooling(nn.Module): trans_std=.0): super(DeformRoIPooling, self).__init__() self.spatial_scale = spatial_scale - self.out_size = out_size + self.out_size = _pair(out_size) self.out_channels = out_channels self.no_trans = no_trans self.group_size = group_size @@ -125,12 +132,12 @@ class DeformRoIPoolingPack(DeformRoIPooling): if not no_trans: seq = [] - ic = self.out_size * self.out_size * self.out_channels + ic = self.out_size[0] * self.out_size[1] * self.out_channels for i in range(self.num_offset_fcs): if i < self.num_offset_fcs - 1: oc = self.deform_fc_channels else: - oc = self.out_size * self.out_size * 2 + oc = self.out_size[0] * self.out_size[1] * 2 seq.append(nn.Linear(ic, oc)) ic = oc if i < self.num_offset_fcs - 1: @@ -156,7 +163,7 @@ class DeformRoIPoolingPack(DeformRoIPooling): self.group_size, self.part_size, self.sample_per_part, self.trans_std) offset = self.offset_fc(x.view(n, -1)) - offset = offset.view(n, 2, self.out_size, self.out_size) + offset = offset.view(n, 2, self.out_size[0], self.out_size[1]) return deform_roi_pooling(data, rois, offset, self.spatial_scale, self.out_size, self.out_channels, self.no_trans, self.group_size, @@ -188,12 +195,12 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling): if not no_trans: offset_fc_seq = [] - ic = self.out_size * self.out_size * self.out_channels + ic = self.out_size[0] * self.out_size[1] * self.out_channels for i in range(self.num_offset_fcs): if i < self.num_offset_fcs - 1: oc = self.deform_fc_channels else: - oc = self.out_size * self.out_size * 2 + oc = self.out_size[0] * self.out_size[1] * 2 offset_fc_seq.append(nn.Linear(ic, oc)) ic = oc if i < self.num_offset_fcs - 1: @@ -203,12 +210,12 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling): self.offset_fc[-1].bias.data.zero_() mask_fc_seq = [] - ic = self.out_size * self.out_size * self.out_channels + ic = self.out_size[0] * self.out_size[1] * self.out_channels for i in range(self.num_mask_fcs): if i < self.num_mask_fcs - 1: oc = self.deform_fc_channels else: - oc = self.out_size * self.out_size + oc = self.out_size[0] * self.out_size[1] mask_fc_seq.append(nn.Linear(ic, oc)) ic = oc if i < self.num_mask_fcs - 1: @@ -236,9 +243,9 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling): self.group_size, self.part_size, self.sample_per_part, self.trans_std) offset = self.offset_fc(x.view(n, -1)) - offset = offset.view(n, 2, self.out_size, self.out_size) + offset = offset.view(n, 2, self.out_size[0], self.out_size[1]) mask = self.mask_fc(x.view(n, -1)) - mask = mask.view(n, 1, self.out_size, self.out_size) + mask = mask.view(n, 1, self.out_size[0], self.out_size[1]) return deform_roi_pooling( data, rois, offset, self.spatial_scale, self.out_size, self.out_channels, self.no_trans, self.group_size,