From 7f9d2eb5e7748ceac1fbc3274b79ae7d473ee288 Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Tue, 18 Sep 2018 16:58:05 +0800 Subject: [PATCH] fix extension to fit pytorch 0.4.1 api --- mmdet/ops/roi_align/src/roi_align_cuda.cpp | 4 ++-- mmdet/ops/roi_pool/src/roi_pool_cuda.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mmdet/ops/roi_align/src/roi_align_cuda.cpp b/mmdet/ops/roi_align/src/roi_align_cuda.cpp index e4c28c1..8551bc5 100644 --- a/mmdet/ops/roi_align/src/roi_align_cuda.cpp +++ b/mmdet/ops/roi_align/src/roi_align_cuda.cpp @@ -17,9 +17,9 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois, const int pooled_height, const int pooled_width, at::Tensor bottom_grad); -#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDAtensor ") +#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") #define CHECK_CONTIGUOUS(x) \ - AT_ASSERT(x.is_contiguous(), #x " must be contiguous ") + AT_CHECK(x.is_contiguous(), #x, " must be contiguous ") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) diff --git a/mmdet/ops/roi_pool/src/roi_pool_cuda.cpp b/mmdet/ops/roi_pool/src/roi_pool_cuda.cpp index 799c151..b05e870 100644 --- a/mmdet/ops/roi_pool/src/roi_pool_cuda.cpp +++ b/mmdet/ops/roi_pool/src/roi_pool_cuda.cpp @@ -16,9 +16,9 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois, const int num_rois, const int pooled_h, const int pooled_w, at::Tensor bottom_grad); -#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDAtensor ") +#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") #define CHECK_CONTIGUOUS(x) \ - AT_ASSERT(x.is_contiguous(), #x " must be contiguous ") + AT_CHECK(x.is_contiguous(), #x, " must be contiguous ") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) -- GitLab