From d076d5c75d1e8b7a82e949116f1450cddd126d20 Mon Sep 17 00:00:00 2001
From: Jiaqi Wang <1155098160@link.cuhk.edu.hk>
Date: Mon, 10 Jun 2019 17:32:31 +0800
Subject: [PATCH]  fix masked_conv cuda runtime error when mask is all zero
 (#779)

* fix mask conv import error

* fix masked_conv cuda runtime error when mask is all zero
---
 .../ops/masked_conv/functions/masked_conv.py  | 29 ++++++++++---------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/mmdet/ops/masked_conv/functions/masked_conv.py b/mmdet/ops/masked_conv/functions/masked_conv.py
index 41ba5a7..eed32b7 100644
--- a/mmdet/ops/masked_conv/functions/masked_conv.py
+++ b/mmdet/ops/masked_conv/functions/masked_conv.py
@@ -30,21 +30,22 @@ class MaskedConv2dFunction(Function):
             math.floor((features.size(3) + 2 * pad_w -
                         (kernel_h - 1) - 1) / stride_w + 1))
         mask_inds = torch.nonzero(mask[0] > 0)
-        mask_h_idx = mask_inds[:, 0].contiguous()
-        mask_w_idx = mask_inds[:, 1].contiguous()
-        data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
-                                      mask_inds.size(0))
-        masked_conv2d_cuda.masked_im2col_forward(features, mask_h_idx,
-                                                 mask_w_idx, kernel_h,
-                                                 kernel_w, pad_h, pad_w,
-                                                 data_col)
-
-        masked_output = torch.addmm(1, bias[:, None], 1,
-                                    weight.view(out_channel, -1), data_col)
         output = features.new_zeros(batch_size, out_channel, out_h, out_w)
-        masked_conv2d_cuda.masked_col2im_forward(masked_output, mask_h_idx,
-                                                 mask_w_idx, out_h, out_w,
-                                                 out_channel, output)
+        if mask_inds.numel() > 0:
+            mask_h_idx = mask_inds[:, 0].contiguous()
+            mask_w_idx = mask_inds[:, 1].contiguous()
+            data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
+                                          mask_inds.size(0))
+            masked_conv2d_cuda.masked_im2col_forward(features, mask_h_idx,
+                                                     mask_w_idx, kernel_h,
+                                                     kernel_w, pad_h, pad_w,
+                                                     data_col)
+
+            masked_output = torch.addmm(1, bias[:, None], 1,
+                                        weight.view(out_channel, -1), data_col)
+            masked_conv2d_cuda.masked_col2im_forward(masked_output, mask_h_idx,
+                                                     mask_w_idx, out_h, out_w,
+                                                     out_channel, output)
         return output
 
     @staticmethod
-- 
GitLab