diff --git a/mmdet/ops/dcn/src/deform_conv_cuda.cpp b/mmdet/ops/dcn/src/deform_conv_cuda.cpp index c4563ed86604279983d4d1a80c2890293b38c61b..54dc8e4b7d7128d1569f76781ff88e05f1954317 100644 --- a/mmdet/ops/dcn/src/deform_conv_cuda.cpp +++ b/mmdet/ops/dcn/src/deform_conv_cuda.cpp @@ -2,6 +2,7 @@ // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c #include <torch/extension.h> +#include <ATen/DeviceGuard.h> #include <cmath> #include <vector> @@ -162,7 +163,8 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, dilationH, dilationW, group, deformable_group); - + at::DeviceGuard guard(input.device()); + input = input.contiguous(); offset = offset.contiguous(); weight = weight.contiguous(); @@ -266,6 +268,7 @@ int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, int deformable_group, int im2col_step) { shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); input = input.contiguous(); offset = offset.contiguous(); @@ -382,7 +385,8 @@ int deform_conv_backward_parameters_cuda( shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, padW, dilationH, dilationW, group, deformable_group); - + at::DeviceGuard guard(input.device()); + input = input.contiguous(); offset = offset.contiguous(); gradOutput = gradOutput.contiguous(); @@ -492,7 +496,8 @@ void modulated_deform_conv_cuda_forward( const bool with_bias) { AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); - + at::DeviceGuard guard(input.device()); + const int batch = input.size(0); const int channels = input.size(1); const int height = input.size(2); @@ -573,6 +578,7 @@ void modulated_deform_conv_cuda_backward( const bool with_bias) { AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); const int batch = input.size(0); const int channels = input.size(1); diff --git a/mmdet/ops/dcn/src/deform_pool_cuda.cpp b/mmdet/ops/dcn/src/deform_pool_cuda.cpp index 803d5f1499049a8368ff7af7a6451c859ce9d310..9e0e3ffcb547a8866311699ee6481b743284dffe 100644 --- a/mmdet/ops/dcn/src/deform_pool_cuda.cpp +++ b/mmdet/ops/dcn/src/deform_pool_cuda.cpp @@ -6,6 +6,7 @@ // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu #include <torch/extension.h> +#include <ATen/DeviceGuard.h> #include <cmath> #include <vector> @@ -33,6 +34,7 @@ void deform_psroi_pooling_cuda_forward( const int output_dim, const int group_size, const int pooled_size, const int part_size, const int sample_per_part, const float trans_std) { AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); const int batch = input.size(0); const int channels = input.size(1); @@ -59,6 +61,7 @@ void deform_psroi_pooling_cuda_backward( const int sample_per_part, const float trans_std) { AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous"); AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); const int batch = input.size(0); const int channels = input.size(1); @@ -84,4 +87,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("deform_psroi_pooling_cuda_backward", &deform_psroi_pooling_cuda_backward, "deform psroi pooling backward(CUDA)"); -} \ No newline at end of file +}