From c5c7ef909dad2ac40f45acdc169e8b908b0fc175 Mon Sep 17 00:00:00 2001
From: Ligeng Zhu <Lyken17@users.noreply.github.com>
Date: Wed, 4 Sep 2019 23:20:09 -0400
Subject: [PATCH] [DeformConv] Fix zero outputs when not running on cuda:0
 (#1326)

* Update deform_conv_cuda.cpp

* Update deform_pool_cuda.cpp
---
 mmdet/ops/dcn/src/deform_conv_cuda.cpp | 12 +++++++++---
 mmdet/ops/dcn/src/deform_pool_cuda.cpp |  5 ++++-
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/mmdet/ops/dcn/src/deform_conv_cuda.cpp b/mmdet/ops/dcn/src/deform_conv_cuda.cpp
index c4563ed..54dc8e4 100644
--- a/mmdet/ops/dcn/src/deform_conv_cuda.cpp
+++ b/mmdet/ops/dcn/src/deform_conv_cuda.cpp
@@ -2,6 +2,7 @@
 // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
 
 #include <torch/extension.h>
+#include <ATen/DeviceGuard.h>
 
 #include <cmath>
 #include <vector>
@@ -162,7 +163,8 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
 
   shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
               dilationH, dilationW, group, deformable_group);
-
+  at::DeviceGuard guard(input.device());
+  
   input = input.contiguous();
   offset = offset.contiguous();
   weight = weight.contiguous();
@@ -266,6 +268,7 @@ int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
                                     int deformable_group, int im2col_step) {
   shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
               dilationH, dilationW, group, deformable_group);
+  at::DeviceGuard guard(input.device());
 
   input = input.contiguous();
   offset = offset.contiguous();
@@ -382,7 +385,8 @@ int deform_conv_backward_parameters_cuda(
 
   shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
               padW, dilationH, dilationW, group, deformable_group);
-
+  at::DeviceGuard guard(input.device());
+  
   input = input.contiguous();
   offset = offset.contiguous();
   gradOutput = gradOutput.contiguous();
@@ -492,7 +496,8 @@ void modulated_deform_conv_cuda_forward(
     const bool with_bias) {
   AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
   AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
-
+  at::DeviceGuard guard(input.device());
+  
   const int batch = input.size(0);
   const int channels = input.size(1);
   const int height = input.size(2);
@@ -573,6 +578,7 @@ void modulated_deform_conv_cuda_backward(
     const bool with_bias) {
   AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
   AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+  at::DeviceGuard guard(input.device());
 
   const int batch = input.size(0);
   const int channels = input.size(1);
diff --git a/mmdet/ops/dcn/src/deform_pool_cuda.cpp b/mmdet/ops/dcn/src/deform_pool_cuda.cpp
index 803d5f1..9e0e3ff 100644
--- a/mmdet/ops/dcn/src/deform_pool_cuda.cpp
+++ b/mmdet/ops/dcn/src/deform_pool_cuda.cpp
@@ -6,6 +6,7 @@
 // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
 
 #include <torch/extension.h>
+#include <ATen/DeviceGuard.h>
 
 #include <cmath>
 #include <vector>
@@ -33,6 +34,7 @@ void deform_psroi_pooling_cuda_forward(
     const int output_dim, const int group_size, const int pooled_size,
     const int part_size, const int sample_per_part, const float trans_std) {
   AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  at::DeviceGuard guard(input.device());
 
   const int batch = input.size(0);
   const int channels = input.size(1);
@@ -59,6 +61,7 @@ void deform_psroi_pooling_cuda_backward(
     const int sample_per_part, const float trans_std) {
   AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
   AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  at::DeviceGuard guard(input.device());
 
   const int batch = input.size(0);
   const int channels = input.size(1);
@@ -84,4 +87,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("deform_psroi_pooling_cuda_backward",
         &deform_psroi_pooling_cuda_backward,
         "deform psroi pooling backward(CUDA)");
-}
\ No newline at end of file
+}
-- 
GitLab