Skip to content
Snippets Groups Projects
Commit c64beaf1 authored by Cao Yuhang's avatar Cao Yuhang Committed by Kai Chen
Browse files

Fix dpool (#1390)

* fix dpool

* add _pair in dpool func
parent 69e93f6f
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,7 @@ import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from . import deform_pool_cuda
......@@ -21,6 +22,12 @@ class DeformRoIPoolingFunction(Function):
part_size=None,
sample_per_part=4,
trans_std=.0):
# TODO: support unsquare RoIs
out_h, out_w = _pair(out_size)
assert isinstance(out_h, int) and isinstance(out_w, int)
assert out_h == out_w
out_size = out_h # out_h and out_w must be equal
ctx.spatial_scale = spatial_scale
ctx.out_size = out_size
ctx.out_channels = out_channels
......@@ -85,7 +92,7 @@ class DeformRoIPooling(nn.Module):
trans_std=.0):
super(DeformRoIPooling, self).__init__()
self.spatial_scale = spatial_scale
self.out_size = out_size
self.out_size = _pair(out_size)
self.out_channels = out_channels
self.no_trans = no_trans
self.group_size = group_size
......@@ -125,12 +132,12 @@ class DeformRoIPoolingPack(DeformRoIPooling):
if not no_trans:
seq = []
ic = self.out_size * self.out_size * self.out_channels
ic = self.out_size[0] * self.out_size[1] * self.out_channels
for i in range(self.num_offset_fcs):
if i < self.num_offset_fcs - 1:
oc = self.deform_fc_channels
else:
oc = self.out_size * self.out_size * 2
oc = self.out_size[0] * self.out_size[1] * 2
seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_offset_fcs - 1:
......@@ -156,7 +163,7 @@ class DeformRoIPoolingPack(DeformRoIPooling):
self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.out_size, self.out_size)
offset = offset.view(n, 2, self.out_size[0], self.out_size[1])
return deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels,
self.no_trans, self.group_size,
......@@ -188,12 +195,12 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
if not no_trans:
offset_fc_seq = []
ic = self.out_size * self.out_size * self.out_channels
ic = self.out_size[0] * self.out_size[1] * self.out_channels
for i in range(self.num_offset_fcs):
if i < self.num_offset_fcs - 1:
oc = self.deform_fc_channels
else:
oc = self.out_size * self.out_size * 2
oc = self.out_size[0] * self.out_size[1] * 2
offset_fc_seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_offset_fcs - 1:
......@@ -203,12 +210,12 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
self.offset_fc[-1].bias.data.zero_()
mask_fc_seq = []
ic = self.out_size * self.out_size * self.out_channels
ic = self.out_size[0] * self.out_size[1] * self.out_channels
for i in range(self.num_mask_fcs):
if i < self.num_mask_fcs - 1:
oc = self.deform_fc_channels
else:
oc = self.out_size * self.out_size
oc = self.out_size[0] * self.out_size[1]
mask_fc_seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_mask_fcs - 1:
......@@ -236,9 +243,9 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.out_size, self.out_size)
offset = offset.view(n, 2, self.out_size[0], self.out_size[1])
mask = self.mask_fc(x.view(n, -1))
mask = mask.view(n, 1, self.out_size, self.out_size)
mask = mask.view(n, 1, self.out_size[0], self.out_size[1])
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.out_channels, self.no_trans, self.group_size,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment