diff --git a/mmdet/ops/roi_align/functions/roi_align.py b/mmdet/ops/roi_align/functions/roi_align.py
index 096badd25673b6b46b3ccb36952e021d00cab835..cd2ee9edd10896f2b1728b28b1d47d4883176980 100644
--- a/mmdet/ops/roi_align/functions/roi_align.py
+++ b/mmdet/ops/roi_align/functions/roi_align.py
@@ -1,4 +1,5 @@
 from torch.autograd import Function
+from torch.nn.modules.utils import _pair
 
 from .. import roi_align_cuda
 
@@ -7,17 +8,8 @@ class RoIAlignFunction(Function):
 
     @staticmethod
     def forward(ctx, features, rois, out_size, spatial_scale, sample_num=0):
-        if isinstance(out_size, int):
-            out_h = out_size
-            out_w = out_size
-        elif isinstance(out_size, tuple):
-            assert len(out_size) == 2
-            assert isinstance(out_size[0], int)
-            assert isinstance(out_size[1], int)
-            out_h, out_w = out_size
-        else:
-            raise TypeError(
-                '"out_size" must be an integer or tuple of integers')
+        out_h, out_w = _pair(out_size)
+        assert isinstance(out_h, int) and isinstance(out_w, int)
         ctx.spatial_scale = spatial_scale
         ctx.sample_num = sample_num
         ctx.save_for_backward(rois)
diff --git a/mmdet/ops/roi_align/modules/roi_align.py b/mmdet/ops/roi_align/modules/roi_align.py
index b83b74e6b7c151eaf627c2b6d3530823ce8cda05..de987bd456c88a093632a96b0fcc57b2a3190e87 100644
--- a/mmdet/ops/roi_align/modules/roi_align.py
+++ b/mmdet/ops/roi_align/modules/roi_align.py
@@ -1,16 +1,28 @@
-from torch.nn.modules.module import Module
-from ..functions.roi_align import RoIAlignFunction
+import torch.nn as nn
+from torch.nn.modules.utils import _pair
 
+from ..functions.roi_align import roi_align
 
-class RoIAlign(Module):
 
-    def __init__(self, out_size, spatial_scale, sample_num=0):
+class RoIAlign(nn.Module):
+
+    def __init__(self,
+                 out_size,
+                 spatial_scale,
+                 sample_num=0,
+                 use_torchvision=False):
         super(RoIAlign, self).__init__()
 
         self.out_size = out_size
         self.spatial_scale = float(spatial_scale)
         self.sample_num = int(sample_num)
+        self.use_torchvision = use_torchvision
 
     def forward(self, features, rois):
-        return RoIAlignFunction.apply(features, rois, self.out_size,
-                                      self.spatial_scale, self.sample_num)
+        if self.use_torchvision:
+            from torchvision.ops import roi_align as tv_roi_align
+            return tv_roi_align(features, rois, _pair(self.out_size),
+                                self.spatial_scale, self.sample_num)
+        else:
+            return roi_align(features, rois, self.out_size, self.spatial_scale,
+                             self.sample_num)
diff --git a/mmdet/ops/roi_pool/functions/roi_pool.py b/mmdet/ops/roi_pool/functions/roi_pool.py
index 068da600e5828d88ef1477c1afe19b81ee363ee2..6de40088c62828f917937a12d9ed2708ce2b85c3 100644
--- a/mmdet/ops/roi_pool/functions/roi_pool.py
+++ b/mmdet/ops/roi_pool/functions/roi_pool.py
@@ -1,5 +1,6 @@
 import torch
 from torch.autograd import Function
+from torch.nn.modules.utils import _pair
 
 from .. import roi_pool_cuda
 
@@ -8,18 +9,9 @@ class RoIPoolFunction(Function):
 
     @staticmethod
     def forward(ctx, features, rois, out_size, spatial_scale):
-        if isinstance(out_size, int):
-            out_h = out_size
-            out_w = out_size
-        elif isinstance(out_size, tuple):
-            assert len(out_size) == 2
-            assert isinstance(out_size[0], int)
-            assert isinstance(out_size[1], int)
-            out_h, out_w = out_size
-        else:
-            raise TypeError(
-                '"out_size" must be an integer or tuple of integers')
         assert features.is_cuda
+        out_h, out_w = _pair(out_size)
+        assert isinstance(out_h, int) and isinstance(out_w, int)
         ctx.save_for_backward(rois)
         num_channels = features.size(1)
         num_rois = rois.size(0)
diff --git a/mmdet/ops/roi_pool/modules/roi_pool.py b/mmdet/ops/roi_pool/modules/roi_pool.py
index d7fffd08c656ee7301aeed5a8262714f4be4157d..c173cbbfd7e9c4f8a7f5cdedc4258fa7e2ccbad2 100644
--- a/mmdet/ops/roi_pool/modules/roi_pool.py
+++ b/mmdet/ops/roi_pool/modules/roi_pool.py
@@ -1,14 +1,22 @@
-from torch.nn.modules.module import Module
+import torch.nn as nn
+from torch.nn.modules.utils import _pair
+
 from ..functions.roi_pool import roi_pool
 
 
-class RoIPool(Module):
+class RoIPool(nn.Module):
 
-    def __init__(self, out_size, spatial_scale):
+    def __init__(self, out_size, spatial_scale, use_torchvision=False):
         super(RoIPool, self).__init__()
 
         self.out_size = out_size
         self.spatial_scale = float(spatial_scale)
+        self.use_torchvision = use_torchvision
 
     def forward(self, features, rois):
-        return roi_pool(features, rois, self.out_size, self.spatial_scale)
+        if self.use_torchvision:
+            from torchvision.ops import roi_pool as tv_roi_pool
+            return tv_roi_pool(features, rois, _pair(self.out_size),
+                               self.spatial_scale)
+        else:
+            return roi_pool(features, rois, self.out_size, self.spatial_scale)