From d076d5c75d1e8b7a82e949116f1450cddd126d20 Mon Sep 17 00:00:00 2001 From: Jiaqi Wang <1155098160@link.cuhk.edu.hk> Date: Mon, 10 Jun 2019 17:32:31 +0800 Subject: [PATCH] fix masked_conv cuda runtime error when mask is all zero (#779) * fix mask conv import error * fix masked_conv cuda runtime error when mask is all zero --- .../ops/masked_conv/functions/masked_conv.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/mmdet/ops/masked_conv/functions/masked_conv.py b/mmdet/ops/masked_conv/functions/masked_conv.py index 41ba5a7..eed32b7 100644 --- a/mmdet/ops/masked_conv/functions/masked_conv.py +++ b/mmdet/ops/masked_conv/functions/masked_conv.py @@ -30,21 +30,22 @@ class MaskedConv2dFunction(Function): math.floor((features.size(3) + 2 * pad_w - (kernel_h - 1) - 1) / stride_w + 1)) mask_inds = torch.nonzero(mask[0] > 0) - mask_h_idx = mask_inds[:, 0].contiguous() - mask_w_idx = mask_inds[:, 1].contiguous() - data_col = features.new_zeros(in_channel * kernel_h * kernel_w, - mask_inds.size(0)) - masked_conv2d_cuda.masked_im2col_forward(features, mask_h_idx, - mask_w_idx, kernel_h, - kernel_w, pad_h, pad_w, - data_col) - - masked_output = torch.addmm(1, bias[:, None], 1, - weight.view(out_channel, -1), data_col) output = features.new_zeros(batch_size, out_channel, out_h, out_w) - masked_conv2d_cuda.masked_col2im_forward(masked_output, mask_h_idx, - mask_w_idx, out_h, out_w, - out_channel, output) + if mask_inds.numel() > 0: + mask_h_idx = mask_inds[:, 0].contiguous() + mask_w_idx = mask_inds[:, 1].contiguous() + data_col = features.new_zeros(in_channel * kernel_h * kernel_w, + mask_inds.size(0)) + masked_conv2d_cuda.masked_im2col_forward(features, mask_h_idx, + mask_w_idx, kernel_h, + kernel_w, pad_h, pad_w, + data_col) + + masked_output = torch.addmm(1, bias[:, None], 1, + weight.view(out_channel, -1), data_col) + masked_conv2d_cuda.masked_col2im_forward(masked_output, mask_h_idx, + mask_w_idx, out_h, out_w, + out_channel, output) return output @staticmethod -- GitLab