Skip to content
Snippets Groups Projects
Commit d076d5c7 authored by Jiaqi Wang's avatar Jiaqi Wang Committed by Kai Chen
Browse files

fix masked_conv cuda runtime error when mask is all zero (#779)

* fix mask conv import error

* fix masked_conv cuda runtime error when mask is all zero
parent 68589d36
No related branches found
No related tags found
No related merge requests found
......@@ -30,21 +30,22 @@ class MaskedConv2dFunction(Function):
math.floor((features.size(3) + 2 * pad_w -
(kernel_h - 1) - 1) / stride_w + 1))
mask_inds = torch.nonzero(mask[0] > 0)
mask_h_idx = mask_inds[:, 0].contiguous()
mask_w_idx = mask_inds[:, 1].contiguous()
data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
mask_inds.size(0))
masked_conv2d_cuda.masked_im2col_forward(features, mask_h_idx,
mask_w_idx, kernel_h,
kernel_w, pad_h, pad_w,
data_col)
masked_output = torch.addmm(1, bias[:, None], 1,
weight.view(out_channel, -1), data_col)
output = features.new_zeros(batch_size, out_channel, out_h, out_w)
masked_conv2d_cuda.masked_col2im_forward(masked_output, mask_h_idx,
mask_w_idx, out_h, out_w,
out_channel, output)
if mask_inds.numel() > 0:
mask_h_idx = mask_inds[:, 0].contiguous()
mask_w_idx = mask_inds[:, 1].contiguous()
data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
mask_inds.size(0))
masked_conv2d_cuda.masked_im2col_forward(features, mask_h_idx,
mask_w_idx, kernel_h,
kernel_w, pad_h, pad_w,
data_col)
masked_output = torch.addmm(1, bias[:, None], 1,
weight.view(out_channel, -1), data_col)
masked_conv2d_cuda.masked_col2im_forward(masked_output, mask_h_idx,
mask_w_idx, out_h, out_w,
out_channel, output)
return output
@staticmethod
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment