Skip to content
Snippets Groups Projects
Commit 7f9d2eb5 authored by Kai Chen's avatar Kai Chen
Browse files

fix extension to fit pytorch 0.4.1 api

parent 0e0b9246
No related branches found
No related tags found
No related merge requests found
......@@ -17,9 +17,9 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
const int pooled_height, const int pooled_width,
at::Tensor bottom_grad);
#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDAtensor ")
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERT(x.is_contiguous(), #x " must be contiguous ")
AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
......
......@@ -16,9 +16,9 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
const int num_rois, const int pooled_h,
const int pooled_w, at::Tensor bottom_grad);
#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDAtensor ")
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERT(x.is_contiguous(), #x " must be contiguous ")
AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment