From 7f9d2eb5e7748ceac1fbc3274b79ae7d473ee288 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Tue, 18 Sep 2018 16:58:05 +0800
Subject: [PATCH] fix extension to fit pytorch 0.4.1 api

---
 mmdet/ops/roi_align/src/roi_align_cuda.cpp | 4 ++--
 mmdet/ops/roi_pool/src/roi_pool_cuda.cpp   | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/mmdet/ops/roi_align/src/roi_align_cuda.cpp b/mmdet/ops/roi_align/src/roi_align_cuda.cpp
index e4c28c1..8551bc5 100644
--- a/mmdet/ops/roi_align/src/roi_align_cuda.cpp
+++ b/mmdet/ops/roi_align/src/roi_align_cuda.cpp
@@ -17,9 +17,9 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
                             const int pooled_height, const int pooled_width,
                             at::Tensor bottom_grad);
 
-#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDAtensor ")
+#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
 #define CHECK_CONTIGUOUS(x) \
-  AT_ASSERT(x.is_contiguous(), #x " must be contiguous ")
+  AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
 #define CHECK_INPUT(x) \
   CHECK_CUDA(x);       \
   CHECK_CONTIGUOUS(x)
diff --git a/mmdet/ops/roi_pool/src/roi_pool_cuda.cpp b/mmdet/ops/roi_pool/src/roi_pool_cuda.cpp
index 799c151..b05e870 100644
--- a/mmdet/ops/roi_pool/src/roi_pool_cuda.cpp
+++ b/mmdet/ops/roi_pool/src/roi_pool_cuda.cpp
@@ -16,9 +16,9 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
                            const int num_rois, const int pooled_h,
                            const int pooled_w, at::Tensor bottom_grad);
 
-#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDAtensor ")
+#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
 #define CHECK_CONTIGUOUS(x) \
-  AT_ASSERT(x.is_contiguous(), #x " must be contiguous ")
+  AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
 #define CHECK_INPUT(x) \
   CHECK_CUDA(x);       \
   CHECK_CONTIGUOUS(x)
-- 
GitLab