diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md
index 77eb71492b242bea660013fa54f6397201358c65..b07d90d8e42b1bf8b00db8f9ac77ab1f48b96527 100644
--- a/docs/GETTING_STARTED.md
+++ b/docs/GETTING_STARTED.md
@@ -74,6 +74,7 @@ python demo/webcam_demo.py configs/faster_rcnn_r50_fpn_1x.py \
 
 ### High-level APIs for testing images
 
+#### Synchronous interface
 Here is an example of building the model and testing given images.
 
 ```python
@@ -103,6 +104,48 @@ for frame in video:
 
 A notebook demo can be found in [demo/inference_demo.ipynb](../demo/inference_demo.ipynb).
 
+#### Asynchronous interface - supported for Python 3.7+
+
+The async interface allows CPU code not to block on GPU-bound inference and enables better CPU/GPU utilization in single-threaded applications. Inference can run concurrently, either across different input samples or across different models of an inference pipeline.
+
+See `tests/async_benchmark.py` for a speed comparison between the synchronous and asynchronous interfaces.
+
+```python
+import asyncio
+import torch
+from mmdet.apis import init_detector, async_inference_detector, show_result
+from mmdet.utils.contextmanagers import concurrent
+
+async def main():
+    config_file = 'configs/faster_rcnn_r50_fpn_1x.py'
+    checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth'
+    device = 'cuda:0'
+    model = init_detector(config_file, checkpoint=checkpoint_file, device=device)
+
+    # queue is used for concurrent inference of multiple images
+    streamqueue = asyncio.Queue()
+    # queue size defines concurrency level
+    streamqueue_size = 3
+
+    for _ in range(streamqueue_size):
+        streamqueue.put_nowait(torch.cuda.Stream(device=device))
+
+    # test a single image and show the results
+    img = 'test.jpg'  # or img = mmcv.imread(img), which will only load it once
+
+    async with concurrent(streamqueue):
+        result = await async_inference_detector(model, img)
+
+    # visualize the results in a new window
+    show_result(img, result, model.CLASSES)
+    # or save the visualization results to image files
+    show_result(img, result, model.CLASSES, out_file='result.jpg')
+
+
+asyncio.run(main())
+
+```
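+
+To run inference concurrently over several images, wrap each call in its own task. A minimal sketch following `tests/async_benchmark.py`, assuming the `model` and `streamqueue` set up inside `main()` above (the image paths are placeholders):
+
+```python
+async def detect(img):
+    # borrow a CUDA stream from the pool for the duration of this call
+    async with concurrent(streamqueue):
+        return await async_inference_detector(model, img)
+
+# concurrency is bounded by the number of streams in the queue
+results = await asyncio.gather(*(detect(img) for img in ['demo1.jpg', 'demo2.jpg']))
+```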
+
 
 ## Train a model
 
diff --git a/mmdet/apis/__init__.py b/mmdet/apis/__init__.py
index f85c7af686996d91faa1939e252479ac699e5d39..4cdf847b25d4f11b5d19ce28c3c35253035bbc12 100644
--- a/mmdet/apis/__init__.py
+++ b/mmdet/apis/__init__.py
@@ -1,9 +1,10 @@
 from .env import get_root_logger, init_dist, set_random_seed
-from .inference import (inference_detector, init_detector, show_result,
-                        show_result_pyplot)
+from .inference import (async_inference_detector, inference_detector,
+                        init_detector, show_result, show_result_pyplot)
 from .train import train_detector
 
 __all__ = [
-    'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
-    'init_detector', 'inference_detector', 'show_result', 'show_result_pyplot'
+    'async_inference_detector', 'init_dist', 'get_root_logger',
+    'set_random_seed', 'train_detector', 'init_detector', 'inference_detector',
+    'show_result', 'show_result_pyplot'
 ]
diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py
index ccf228a77eb42266dd5e61cc7860ca4af34d7036..6724e85c85009366cef8c070a466f8d1c2b65ff8 100644
--- a/mmdet/apis/inference.py
+++ b/mmdet/apis/inference.py
@@ -84,7 +84,34 @@ def inference_detector(model, img):
     # forward the model
     with torch.no_grad():
         result = model(return_loss=False, rescale=True, **data)
+    return result
+
+
+async def async_inference_detector(model, img):
+    """Async inference image(s) with the detector.
+
+    Args:
+        model (nn.Module): The loaded detector.
+        imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
+            images.
+
+    Returns:
+        Awaitable detection results.
+    """
+    cfg = model.cfg
+    device = next(model.parameters()).device  # model device
+    # build the data pipeline
+    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
+    test_pipeline = Compose(test_pipeline)
+    # prepare data
+    data = dict(img=img)
+    data = test_pipeline(data)
+    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
 
+    # We deliberately do not restore the `torch.is_grad_enabled()` value,
+    # since the execution of concurrent inference coroutines can overlap
+    torch.set_grad_enabled(False)
+    result = await model.aforward_test(rescale=True, **data)
     return result
 
 
diff --git a/mmdet/models/anchor_heads/rpn_head.py b/mmdet/models/anchor_heads/rpn_head.py
index 50f1cc515e05f0fa4e485960db2035156c9434e7..f88b949cf8b3051610697d15772bf1b7ea938a06 100644
--- a/mmdet/models/anchor_heads/rpn_head.py
+++ b/mmdet/models/anchor_heads/rpn_head.py
@@ -65,7 +65,6 @@ class RPNHead(AnchorHead):
             rpn_cls_score = cls_scores[idx]
             rpn_bbox_pred = bbox_preds[idx]
             assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
-            anchors = mlvl_anchors[idx]
             rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
             if self.use_sigmoid_cls:
                 rpn_cls_score = rpn_cls_score.reshape(-1)
@@ -74,6 +73,7 @@ class RPNHead(AnchorHead):
                 rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                 scores = rpn_cls_score.softmax(dim=1)[:, 1]
             rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+            anchors = mlvl_anchors[idx]
             if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
                 _, topk_inds = scores.topk(cfg.nms_pre)
                 rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
index 42f47dde8ed5043dd2c2ccc9ec11836ea749028c..6a33dab4c59ec8551dc3fa5e5a8965954c063d56 100644
--- a/mmdet/models/detectors/base.py
+++ b/mmdet/models/detectors/base.py
@@ -61,6 +61,10 @@ class BaseDetector(nn.Module):
         """
         pass
 
+    @abstractmethod
+    async def async_simple_test(self, img, img_meta, **kwargs):
+        pass
+
     @abstractmethod
     def simple_test(self, img, img_meta, **kwargs):
         pass
@@ -74,6 +78,26 @@ class BaseDetector(nn.Module):
             logger = logging.getLogger()
             logger.info('load model from: {}'.format(pretrained))
 
+    async def aforward_test(self, *, img, img_meta, **kwargs):
+        for var, name in [(img, 'img'), (img_meta, 'img_meta')]:
+            if not isinstance(var, list):
+                raise TypeError('{} must be a list, but got {}'.format(
+                    name, type(var)))
+
+        num_augs = len(img)
+        if num_augs != len(img_meta):
+            raise ValueError(
+                'num of augmentations ({}) != num of image meta ({})'.format(
+                    len(img), len(img_meta)))
+        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
+        imgs_per_gpu = img[0].size(0)
+        assert imgs_per_gpu == 1
+
+        if num_augs == 1:
+            return await self.async_simple_test(img[0], img_meta[0], **kwargs)
+        else:
+            raise NotImplementedError
+
     def forward_test(self, imgs, img_metas, **kwargs):
         """
         Args:
diff --git a/mmdet/models/detectors/test_mixins.py b/mmdet/models/detectors/test_mixins.py
index 05b2081271e09ed9c845c593cbe4946a5ab31e97..84a96d1679044c658cc6b7014006ef73b210054a 100644
--- a/mmdet/models/detectors/test_mixins.py
+++ b/mmdet/models/detectors/test_mixins.py
@@ -1,11 +1,33 @@
+import logging
+import sys
+
 import torch
 
 from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_bboxes,
                         merge_aug_masks, merge_aug_proposals, multiclass_nms)
 
+logger = logging.getLogger(__name__)
+
+if sys.version_info >= (3, 7):
+    from mmdet.utils.contextmanagers import completed
+
 
 class RPNTestMixin(object):
 
+    if sys.version_info >= (3, 7):
+
+        async def async_test_rpn(self, x, img_meta, rpn_test_cfg):
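+            # pop the async-only key so it is not passed on to get_bboxes below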
+            sleep_interval = rpn_test_cfg.pop("async_sleep_interval", 0.025)
+            async with completed(
+                    __name__, "rpn_head_forward",
+                    sleep_interval=sleep_interval):
+                rpn_outs = self.rpn_head(x)
+
+            proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
+
+            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
+            return proposal_list
+
     def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
         rpn_outs = self.rpn_head(x)
         proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
@@ -37,6 +59,41 @@ class RPNTestMixin(object):
 
 class BBoxTestMixin(object):
 
+    if sys.version_info >= (3, 7):
+
+        async def async_test_bboxes(self,
+                                    x,
+                                    img_meta,
+                                    proposals,
+                                    rcnn_test_cfg,
+                                    rescale=False,
+                                    bbox_semaphore=None,
+                                    global_lock=None):
+            """Async test only det bboxes without augmentation."""
+            rois = bbox2roi(proposals)
+            roi_feats = self.bbox_roi_extractor(
+                x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
+            if self.with_shared_head:
+                roi_feats = self.shared_head(roi_feats)
+            sleep_interval = rcnn_test_cfg.get("async_sleep_interval", 0.017)
+
+            async with completed(
+                    __name__, "bbox_head_forward",
+                    sleep_interval=sleep_interval):
+                cls_score, bbox_pred = self.bbox_head(roi_feats)
+
+            img_shape = img_meta[0]['img_shape']
+            scale_factor = img_meta[0]['scale_factor']
+            det_bboxes, det_labels = self.bbox_head.get_det_bboxes(
+                rois,
+                cls_score,
+                bbox_pred,
+                img_shape,
+                scale_factor,
+                rescale=rescale,
+                cfg=rcnn_test_cfg)
+            return det_bboxes, det_labels
+
     def simple_test_bboxes(self,
                            x,
                            img_meta,
@@ -102,6 +159,46 @@ class BBoxTestMixin(object):
 
 class MaskTestMixin(object):
 
+    if sys.version_info >= (3, 7):
+
+        async def async_test_mask(self,
+                                  x,
+                                  img_meta,
+                                  det_bboxes,
+                                  det_labels,
+                                  rescale=False,
+                                  mask_test_cfg=None):
+            # image shape of the first image in the batch (only one)
+            ori_shape = img_meta[0]['ori_shape']
+            scale_factor = img_meta[0]['scale_factor']
+            if det_bboxes.shape[0] == 0:
+                segm_result = [[]
+                               for _ in range(self.mask_head.num_classes - 1)]
+            else:
+                _bboxes = (
+                    det_bboxes[:, :4] * scale_factor
+                    if rescale else det_bboxes)
+                mask_rois = bbox2roi([_bboxes])
+                mask_feats = self.mask_roi_extractor(
+                    x[:len(self.mask_roi_extractor.featmap_strides)],
+                    mask_rois)
+
+                if self.with_shared_head:
+                    mask_feats = self.shared_head(mask_feats)
+                if mask_test_cfg and mask_test_cfg.get('async_sleep_interval'):
+                    sleep_interval = mask_test_cfg['async_sleep_interval']
+                else:
+                    sleep_interval = 0.035
+                async with completed(
+                        __name__,
+                        "mask_head_forward",
+                        sleep_interval=sleep_interval):
+                    mask_pred = self.mask_head(mask_feats)
+                segm_result = self.mask_head.get_seg_masks(
+                    mask_pred, _bboxes, det_labels, self.test_cfg.rcnn,
+                    ori_shape, scale_factor, rescale)
+            return segm_result
+
     def simple_test_mask(self,
                          x,
                          img_meta,
diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py
index 6a75eb3640f4ee552ca84b107f7ed3b56fb77f54..155819518a2879eeb1fadf57c177f74deb22c48d 100644
--- a/mmdet/models/detectors/two_stage.py
+++ b/mmdet/models/detectors/two_stage.py
@@ -260,14 +260,49 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
 
         return losses
 
+    async def async_simple_test(self,
+                                img,
+                                img_meta,
+                                proposals=None,
+                                rescale=False):
+        """Async test without augmentation."""
+        assert self.with_bbox, "Bbox head must be implemented."
+        x = self.extract_feat(img)
+
+        if proposals is None:
+            proposal_list = await self.async_test_rpn(x, img_meta,
+                                                      self.test_cfg.rpn)
+        else:
+            proposal_list = proposals
+
+        det_bboxes, det_labels = await self.async_test_bboxes(
+            x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
+        bbox_results = bbox2result(det_bboxes, det_labels,
+                                   self.bbox_head.num_classes)
+
+        if not self.with_mask:
+            return bbox_results
+        else:
+            segm_results = await self.async_test_mask(
+                x,
+                img_meta,
+                det_bboxes,
+                det_labels,
+                rescale=rescale,
+                mask_test_cfg=self.test_cfg.get('mask'))
+            return bbox_results, segm_results
+
     def simple_test(self, img, img_meta, proposals=None, rescale=False):
         """Test without augmentation."""
         assert self.with_bbox, "Bbox head must be implemented."
 
         x = self.extract_feat(img)
 
-        proposal_list = self.simple_test_rpn(
-            x, img_meta, self.test_cfg.rpn) if proposals is None else proposals
+        if proposals is None:
+            proposal_list = self.simple_test_rpn(x, img_meta,
+                                                 self.test_cfg.rpn)
+        else:
+            proposal_list = proposals
 
         det_bboxes, det_labels = self.simple_test_bboxes(
             x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
diff --git a/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu b/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
index 402b4499cae114f2d2b30f01f182a5198c910298..e7a26f2e830846f80272bcd8c5ce0def34593c95 100644
--- a/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+++ b/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
@@ -61,6 +61,7 @@
 // modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
 
 #include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
 #include <THC/THCAtomics.cuh>
 #include <stdio.h>
 #include <math.h>
@@ -261,7 +262,7 @@ void deformable_im2col(
         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
         scalar_t *data_col_ = data_col.data<scalar_t>();
 
-        deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+        deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
             num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
             pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
             channel_per_deformable_group, parallel_imgs, channels, deformable_group,
@@ -355,7 +356,7 @@ void deformable_col2im(
         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
         scalar_t *grad_im_ = grad_im.data<scalar_t>();
 
-        deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+        deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
             num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
             ksize_w, pad_h, pad_w, stride_h, stride_w,
             dilation_h, dilation_w, channel_per_deformable_group,
@@ -454,7 +455,7 @@ void deformable_col2im_coord(
         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
         scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
 
-        deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+        deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
             num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
             ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
             dilation_h, dilation_w, channel_per_deformable_group,
@@ -784,7 +785,7 @@ void modulated_deformable_im2col_cuda(
         const scalar_t *data_mask_ = data_mask.data<scalar_t>();
         scalar_t *data_col_ = data_col.data<scalar_t>();
 
-        modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+        modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
             num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,
             pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
             batch_size, channels, deformable_group, height_col, width_col, data_col_);
@@ -816,7 +817,7 @@ void modulated_deformable_col2im_cuda(
         const scalar_t *data_mask_ = data_mask.data<scalar_t>();
         scalar_t *grad_im_ = grad_im.data<scalar_t>();
 
-        modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+        modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
             num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
             kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
             dilation_h, dilation_w, channel_per_deformable_group,
@@ -851,7 +852,7 @@ void modulated_deformable_col2im_coord_cuda(
         scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
         scalar_t *grad_mask_ = grad_mask.data<scalar_t>();
 
-        modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+        modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
             num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
             kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
             dilation_h, dilation_w, channel_per_deformable_group,
diff --git a/mmdet/ops/dcn/src/deform_pool_cuda_kernel.cu b/mmdet/ops/dcn/src/deform_pool_cuda_kernel.cu
index 1922d72442253353b29200a00a204b34a9153fa4..05b00d4be618353b404540469bf6118902651ca2 100644
--- a/mmdet/ops/dcn/src/deform_pool_cuda_kernel.cu
+++ b/mmdet/ops/dcn/src/deform_pool_cuda_kernel.cu
@@ -296,7 +296,7 @@ void DeformablePSROIPoolForward(const at::Tensor data,
         scalar_t *top_data = out.data<scalar_t>();
         scalar_t *top_count_data = top_count.data<scalar_t>();
 
-        DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
+        DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
             count, bottom_data, (scalar_t)spatial_scale, channels, height, width, pooled_height, pooled_width,
             bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, output_dim,
             group_size, part_size, num_classes, channels_each_class, top_data, top_count_data);
@@ -349,7 +349,7 @@ void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
         scalar_t *bottom_trans_diff = no_trans ? NULL : trans_grad.data<scalar_t>();
         const scalar_t *top_count_data = top_count.data<scalar_t>();
 
-        DeformablePSROIPoolBackwardAccKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
+        DeformablePSROIPoolBackwardAccKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
             count, top_diff, top_count_data, num_rois, (scalar_t)spatial_scale, channels, height, width,
             pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff,
             bottom_data, bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part,
@@ -361,4 +361,4 @@ void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
   {
     printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
   }
-}
\ No newline at end of file
+}
diff --git a/mmdet/ops/masked_conv/src/masked_conv2d_kernel.cu b/mmdet/ops/masked_conv/src/masked_conv2d_kernel.cu
index 2312d1200c926d7c3ecc5fc7655b344f63de509e..81c785bbe41461fa8a4d380dbbef60dbe677cf6a 100644
--- a/mmdet/ops/masked_conv/src/masked_conv2d_kernel.cu
+++ b/mmdet/ops/masked_conv/src/masked_conv2d_kernel.cu
@@ -1,4 +1,5 @@
 #include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
 #include <THC/THCAtomics.cuh>
 
 #define CUDA_1D_KERNEL_LOOP(i, n)                            \
@@ -63,7 +64,8 @@ int MaskedIm2colForwardLaucher(const at::Tensor bottom_data, const int height,
         const int64_t *mask_w_idx_ = mask_w_idx.data<int64_t>();
         scalar_t *top_data_ = top_data.data<scalar_t>();
         MaskedIm2colForward<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0,
+               at::cuda::getCurrentCUDAStream()>>>(
                 output_size, bottom_data_, height, width, kernel_h, kernel_w,
                 pad_h, pad_w, mask_h_idx_, mask_w_idx_, mask_cnt, top_data_);
       }));
@@ -103,7 +105,7 @@ int MaskedCol2imForwardLaucher(const at::Tensor bottom_data, const int height,
         scalar_t *top_data_ = top_data.data<scalar_t>();
 
         MaskedCol2imForward<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
                 output_size, bottom_data_, height, width, channels, mask_h_idx_,
                 mask_w_idx_, mask_cnt, top_data_);
       }));
diff --git a/mmdet/ops/nms/src/nms_kernel.cu b/mmdet/ops/nms/src/nms_kernel.cu
index c0811cfcb4802355f5784d1e737374a4b08ebb55..ada9bea25237046b4001fea0e0c7061c6a53886f 100644
--- a/mmdet/ops/nms/src/nms_kernel.cu
+++ b/mmdet/ops/nms/src/nms_kernel.cu
@@ -96,16 +96,19 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
   dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
               THCCeilDiv(boxes_num, threadsPerBlock));
   dim3 threads(threadsPerBlock);
-  nms_kernel<<<blocks, threads>>>(boxes_num,
+  nms_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(boxes_num,
                                   nms_overlap_thresh,
                                   boxes_dev,
                                   mask_dev);
 
   std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
-  THCudaCheck(cudaMemcpy(&mask_host[0],
-                        mask_dev,
-                        sizeof(unsigned long long) * boxes_num * col_blocks,
-                        cudaMemcpyDeviceToHost));
+  // mask_host is pageable host memory, so this device-to-host copy
+  // completes before cudaMemcpyAsync returns
+  THCudaCheck(cudaMemcpyAsync(&mask_host[0],
+                              mask_dev,
+                              sizeof(unsigned long long) * boxes_num * col_blocks,
+                              cudaMemcpyDeviceToHost,
+                              at::cuda::getCurrentCUDAStream()));
 
   std::vector<unsigned long long> remv(col_blocks);
   memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
diff --git a/mmdet/ops/roi_align/src/roi_align_cuda.cpp b/mmdet/ops/roi_align/src/roi_align_cuda.cpp
index 06a73aa05e49b547cc734af3e223036edb29f3b2..12c16f009d7dc9e8b0298beb5385e8cf0256becc 100644
--- a/mmdet/ops/roi_align/src/roi_align_cuda.cpp
+++ b/mmdet/ops/roi_align/src/roi_align_cuda.cpp
@@ -1,5 +1,7 @@
 #include <torch/extension.h>
 
+#include <ATen/ATen.h>
+
 #include <cmath>
 #include <vector>
 
diff --git a/mmdet/ops/roi_align/src/roi_align_kernel.cu b/mmdet/ops/roi_align/src/roi_align_kernel.cu
index eb7cdaf1ffa5bc1ff25f1c873d0f23e9590b0912..3208b2806155cbdd46e1c9230ad0c4db03d33f51 100644
--- a/mmdet/ops/roi_align/src/roi_align_kernel.cu
+++ b/mmdet/ops/roi_align/src/roi_align_kernel.cu
@@ -1,4 +1,5 @@
 #include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
 #include <THC/THCAtomics.cuh>
 
 #define CUDA_1D_KERNEL_LOOP(i, n)                            \
@@ -131,7 +132,7 @@ int ROIAlignForwardLaucher(const at::Tensor features, const at::Tensor rois,
         scalar_t *top_data = output.data<scalar_t>();
 
         ROIAlignForward<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
                 output_size, bottom_data, rois_data, scalar_t(spatial_scale),
                 sample_num, channels, height, width, pooled_height,
                 pooled_width, top_data);
@@ -272,7 +273,7 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
         }
 
         ROIAlignBackward<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
                 output_size, top_diff, rois_data, spatial_scale, sample_num,
                 channels, height, width, pooled_height, pooled_width,
                 bottom_diff);
diff --git a/mmdet/ops/roi_pool/src/roi_pool_kernel.cu b/mmdet/ops/roi_pool/src/roi_pool_kernel.cu
index 25ba98532850422f3d54fc3b00c35ebdf1dd8ae3..37c3d0c4926696bbf3ca7cbad37ac4991ef80055 100644
--- a/mmdet/ops/roi_pool/src/roi_pool_kernel.cu
+++ b/mmdet/ops/roi_pool/src/roi_pool_kernel.cu
@@ -1,4 +1,5 @@
 #include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
 #include <THC/THCAtomics.cuh>
 
 #define CUDA_1D_KERNEL_LOOP(i, n)                            \
@@ -93,7 +94,7 @@ int ROIPoolForwardLaucher(const at::Tensor features, const at::Tensor rois,
         int *argmax_data = argmax.data<int>();
 
         ROIPoolForward<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
                 output_size, bottom_data, rois_data, scalar_t(spatial_scale),
                 channels, height, width, pooled_h, pooled_w, top_data,
                 argmax_data);
@@ -146,7 +147,7 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
         }
 
         ROIPoolBackward<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
                 output_size, top_diff, rois_data, argmax_data,
                 scalar_t(spatial_scale), channels, height, width, pooled_h,
                 pooled_w, bottom_diff);
diff --git a/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu b/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu
index 6a9104c1cd9ce1e817e1b75791fd822f819d9a85..0e152d38fd15858c4b8c51936aa68f78796a319a 100644
--- a/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu
+++ b/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu
@@ -120,7 +120,7 @@ at::Tensor SigmoidFocalLoss_forward_cuda(const at::Tensor &logits,
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       logits.scalar_type(), "SigmoidFocalLoss_forward", [&] {
-        SigmoidFocalLossForward<scalar_t><<<grid, block>>>(
+        SigmoidFocalLossForward<scalar_t><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
             losses_size, logits.contiguous().data<scalar_t>(),
             targets.contiguous().data<int64_t>(), num_classes, gamma, alpha,
             num_samples, losses.data<scalar_t>());
@@ -159,7 +159,7 @@ at::Tensor SigmoidFocalLoss_backward_cuda(const at::Tensor &logits,
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       logits.scalar_type(), "SigmoidFocalLoss_backward", [&] {
-        SigmoidFocalLossBackward<scalar_t><<<grid, block>>>(
+        SigmoidFocalLossBackward<scalar_t><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
             d_logits_size, logits.contiguous().data<scalar_t>(),
             targets.contiguous().data<int64_t>(),
             d_losses.contiguous().data<scalar_t>(), num_classes, gamma, alpha,
diff --git a/mmdet/utils/contextmanagers.py b/mmdet/utils/contextmanagers.py
new file mode 100644
index 0000000000000000000000000000000000000000..12073bef93219fedd96713aa7c6452c1530679a5
--- /dev/null
+++ b/mmdet/utils/contextmanagers.py
@@ -0,0 +1,122 @@
+# coding: utf-8
+import asyncio
+import contextlib
+import logging
+import os
+import time
+from typing import List
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+DEBUG_COMPLETED_TIME = bool(os.environ.get('DEBUG_COMPLETED_TIME', False))
+
+
+@contextlib.asynccontextmanager
+async def completed(trace_name="",
+                    name="",
+                    sleep_interval=0.05,
+                    streams: List[torch.cuda.Stream] = None):
+    """
+    Async context manager that waits for work to complete on
+    given CUDA streams.
+
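+    Usage sketch, mirroring how it is used in
+    ``mmdet/models/detectors/test_mixins.py``:
+
+        async with completed(__name__, "rpn_head_forward",
+                             sleep_interval=sleep_interval):
+            rpn_outs = self.rpn_head(x)
+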
+    """
+    if not torch.cuda.is_available():
+        yield
+        return
+
+    stream_before_context_switch = torch.cuda.current_stream()
+    if not streams:
+        streams = [stream_before_context_switch]
+    else:
+        streams = [s if s else stream_before_context_switch for s in streams]
+
+    end_events = [
+        torch.cuda.Event(enable_timing=DEBUG_COMPLETED_TIME) for _ in streams
+    ]
+
+    if DEBUG_COMPLETED_TIME:
+        start = torch.cuda.Event(enable_timing=True)
+        stream_before_context_switch.record_event(start)
+
+        cpu_start = time.monotonic()
+    logger.debug("%s %s starting, streams: %s", trace_name, name, streams)
+    grad_enabled_before = torch.is_grad_enabled()
+    try:
+        yield
+    finally:
+        current_stream = torch.cuda.current_stream()
+        assert current_stream == stream_before_context_switch
+
+        if DEBUG_COMPLETED_TIME:
+            cpu_end = time.monotonic()
+        for i, stream in enumerate(streams):
+            event = end_events[i]
+            stream.record_event(event)
+
+        grad_enabled_after = torch.is_grad_enabled()
+
+        # observed change of torch.is_grad_enabled() during concurrent run of
+        # async_test_bboxes code
+        assert grad_enabled_before == grad_enabled_after, \
+            "Unexpected is_grad_enabled() value change"
+
+        are_done = [e.query() for e in end_events]
+        logger.debug("%s %s completed: %s streams: %s", trace_name, name,
+                     are_done, streams)
+        with torch.cuda.stream(stream_before_context_switch):
+            while not all(are_done):
+                await asyncio.sleep(sleep_interval)
+                are_done = [e.query() for e in end_events]
+                logger.debug("%s %s completed: %s streams: %s", trace_name,
+                             name, are_done, streams)
+
+        current_stream = torch.cuda.current_stream()
+        assert current_stream == stream_before_context_switch
+
+        if DEBUG_COMPLETED_TIME:
+            cpu_time = (cpu_end - cpu_start) * 1000
+            stream_times_ms = ""
+            for i, stream in enumerate(streams):
+                elapsed_time = start.elapsed_time(end_events[i])
+                stream_times_ms += " {} {:.2f} ms".format(stream, elapsed_time)
+            logger.info("%s %s cpu_time %.2f ms%s", trace_name, name,
+                        cpu_time, stream_times_ms)
+
+
+@contextlib.asynccontextmanager
+async def concurrent(streamqueue: asyncio.Queue,
+                     trace_name="concurrent",
+                     name="stream"):
+    """Run code concurrently in different streams.
+
+    :param streamqueue: asyncio.Queue instance.
+
+    Queue tasks define the pool of streams used for concurrent execution.
+
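+    Usage sketch, as in ``docs/GETTING_STARTED.md``:
+
+        streamqueue = asyncio.Queue()
+        streamqueue.put_nowait(torch.cuda.Stream(device=device))
+
+        async with concurrent(streamqueue):
+            result = await async_inference_detector(model, img)
+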
+    """
+    if not torch.cuda.is_available():
+        yield
+        return
+
+    initial_stream = torch.cuda.current_stream()
+
+    with torch.cuda.stream(initial_stream):
+        stream = await streamqueue.get()
+        assert isinstance(stream, torch.cuda.Stream)
+
+        try:
+            with torch.cuda.stream(stream):
+                logger.debug("%s %s is starting, stream: %s", trace_name, name,
+                             stream)
+                yield
+                current = torch.cuda.current_stream()
+                assert current == stream
+                logger.debug("%s %s has finished, stream: %s", trace_name,
+                             name, stream)
+        finally:
+            streamqueue.task_done()
+            streamqueue.put_nowait(stream)
diff --git a/mmdet/utils/profiling.py b/mmdet/utils/profiling.py
new file mode 100644
index 0000000000000000000000000000000000000000..58b1c87dd3f7a427bfa59b003b58e29158666d36
--- /dev/null
+++ b/mmdet/utils/profiling.py
@@ -0,0 +1,41 @@
+import contextlib
+import sys
+import time
+
+import torch
+
+if sys.version_info >= (3, 7):
+
+    @contextlib.contextmanager
+    def profile_time(trace_name,
+                     name,
+                     enabled=True,
+                     stream=None,
+                     end_stream=None):
+        """Print time spent by CPU and GPU.
+
+        Useful as a temporary context manager to find sweet spots of
+        code suitable for async implementation.
+
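+        Usage sketch, as in ``tests/async_benchmark.py``:
+
+            with profile_time('benchmark', 'async'):
+                results = await asyncio.gather(*tasks)
+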
+        """
+        if (not enabled) or not torch.cuda.is_available():
+            yield
+            return
+        stream = stream if stream else torch.cuda.current_stream()
+        end_stream = end_stream if end_stream else stream
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        stream.record_event(start)
+        try:
+            cpu_start = time.monotonic()
+            yield
+        finally:
+            cpu_end = time.monotonic()
+            end_stream.record_event(end)
+            end.synchronize()
+            cpu_time = (cpu_end - cpu_start) * 1000
+            gpu_time = start.elapsed_time(end)
+            msg = "{} {} cpu_time {:.2f} ms ".format(trace_name, name,
+                                                     cpu_time)
+            msg += "gpu_time {:.2f} ms stream {}".format(gpu_time, stream)
+            print(msg, end_stream)
diff --git a/setup.py b/setup.py
index 43a2ec219241d62854c027de1345fcddf6dfc02b..54d58777268ae33f0da7cdc7135f9f4969e5a1d1 100644
--- a/setup.py
+++ b/setup.py
@@ -162,7 +162,7 @@ if __name__ == '__main__':
         ],
         license='Apache License 2.0',
         setup_requires=['pytest-runner', 'cython', 'numpy'],
-        tests_require=['pytest', 'xdoctest'],
+        tests_require=['pytest', 'xdoctest', 'asynctest'],
         install_requires=get_requirements(),
         ext_modules=[
             make_cuda_ext(
diff --git a/tests/async_benchmark.py b/tests/async_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..0017783d354d24a0c9f38953a7566f7545a41bb6
--- /dev/null
+++ b/tests/async_benchmark.py
@@ -0,0 +1,104 @@
+# coding: utf-8
+
+import asyncio
+import os
+import shutil
+import urllib.request
+
+import mmcv
+import torch
+
+from mmdet.apis import (async_inference_detector, inference_detector,
+                        init_detector, show_result)
+from mmdet.utils.contextmanagers import concurrent
+from mmdet.utils.profiling import profile_time
+
+
+async def main():
+    """
+
+    Benchmark between async and synchronous inference interfaces.
+
+    Sample runs for 20 demo images on K80 GPU, model - mask_rcnn_r50_fpn_1x:
+
+    async	sync
+
+    7981.79 ms	9660.82 ms
+    8074.52 ms	9660.94 ms
+    7976.44 ms	9406.83 ms
+
+    Async variant takes about 0.83-0.85 of the time of the synchronous
+    interface.
+
+    """
+    project_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+
+    config_file = os.path.join(project_dir, 'configs/mask_rcnn_r50_fpn_1x.py')
+    checkpoint_file = os.path.join(
+        project_dir, 'checkpoints/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth')
+
+    if not os.path.exists(checkpoint_file):
+        url = ('https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection'
+               '/models/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth')
+        print('Downloading {} ...'.format(url))
+        local_filename, _ = urllib.request.urlretrieve(url)
+        os.makedirs(os.path.dirname(checkpoint_file), exist_ok=True)
+        shutil.move(local_filename, checkpoint_file)
+        print('Saved as {}'.format(checkpoint_file))
+    else:
+        print('Using existing checkpoint {}'.format(checkpoint_file))
+
+    device = 'cuda:0'
+    model = init_detector(
+        config_file, checkpoint=checkpoint_file, device=device)
+
+    # queue is used for concurrent inference of multiple images
+    streamqueue = asyncio.Queue()
+    # queue size defines concurrency level
+    streamqueue_size = 4
+
+    for _ in range(streamqueue_size):
+        streamqueue.put_nowait(torch.cuda.Stream(device=device))
+
+    # test a single image and show the results
+    img = mmcv.imread(os.path.join(project_dir, 'demo/demo.jpg'))
+
+    # warmup
+    await async_inference_detector(model, img)
+
+    async def detect(img):
+        async with concurrent(streamqueue):
+            return await async_inference_detector(model, img)
+
+    num_of_images = 20
+    with profile_time('benchmark', 'async'):
+        tasks = [
+            asyncio.create_task(detect(img)) for _ in range(num_of_images)
+        ]
+        async_results = await asyncio.gather(*tasks)
+
+    with torch.cuda.stream(torch.cuda.default_stream()):
+        with profile_time('benchmark', 'sync'):
+            sync_results = [
+                inference_detector(model, img) for _ in range(num_of_images)
+            ]
+
+    result_dir = os.path.join(project_dir, 'demo')
+    show_result(
+        img,
+        async_results[0],
+        model.CLASSES,
+        score_thr=0.5,
+        show=False,
+        out_file=os.path.join(result_dir, 'result_async.jpg'))
+    show_result(
+        img,
+        sync_results[0],
+        model.CLASSES,
+        score_thr=0.5,
+        show=False,
+        out_file=os.path.join(result_dir, 'result_sync.jpg'))
+
+
+if __name__ == '__main__':
+    asyncio.run(main())
diff --git a/tests/requirements.txt b/tests/requirements.txt
index a1f3efbbe93d05ef4e1eceeb8d5640b56daf1d3e..ff60968eef8a22c8f8fa98c090b5ce9702ecd6b4 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -4,3 +4,4 @@ yapf
 pytest-cov
 codecov
 xdoctest >= 0.10.0
+asynctest
diff --git a/tests/test_async.py b/tests/test_async.py
new file mode 100644
index 0000000000000000000000000000000000000000..68ecde33d0b5d937041060275b7c23dbbc471a16
--- /dev/null
+++ b/tests/test_async.py
@@ -0,0 +1,78 @@
+"""Tests for async interface."""
+
+import asyncio
+import os
+import sys
+
+import asynctest
+import mmcv
+import torch
+
+from mmdet.apis import async_inference_detector, init_detector
+
+if sys.version_info >= (3, 7):
+    from mmdet.utils.contextmanagers import concurrent
+
+
+class AsyncTestCase(asynctest.TestCase):
+    use_default_loop = False
+    forbid_get_event_loop = True
+
+    TEST_TIMEOUT = int(os.getenv("ASYNCIO_TEST_TIMEOUT", "30"))
+
+    def _run_test_method(self, method):
+        result = method()
+        if asyncio.iscoroutine(result):
+            self.loop.run_until_complete(
+                asyncio.wait_for(result, timeout=self.TEST_TIMEOUT))
+
+
+class MaskRCNNDetector:
+
+    def __init__(self,
+                 model_config,
+                 checkpoint=None,
+                 streamqueue_size=3,
+                 device="cuda:0"):
+
+        self.streamqueue_size = streamqueue_size
+        self.device = device
+        # build the model and load checkpoint
+        self.model = init_detector(
+            model_config, checkpoint=checkpoint, device=self.device)
+        self.streamqueue = None
+
+    async def init(self):
+        self.streamqueue = asyncio.Queue()
+        for _ in range(self.streamqueue_size):
+            stream = torch.cuda.Stream(device=self.device)
+            self.streamqueue.put_nowait(stream)
+
+    if sys.version_info >= (3, 7):
+
+        async def apredict(self, img):
+            if isinstance(img, str):
+                img = mmcv.imread(img)
+            async with concurrent(self.streamqueue):
+                result = await async_inference_detector(self.model, img)
+            return result
+
+
+class AsyncInferenceTestCase(AsyncTestCase):
+
+    if sys.version_info >= (3, 7):
+
+        async def test_simple_inference(self):
+            if not torch.cuda.is_available():
+                import pytest
+
+                pytest.skip("test requires GPU and torch+cuda")
+
+            root_dir = os.path.dirname(os.path.dirname(__file__))
+            model_config = os.path.join(root_dir,
+                                        "configs/mask_rcnn_r50_fpn_1x.py")
+            detector = MaskRCNNDetector(model_config)
+            await detector.init()
+            img_path = os.path.join(root_dir, "demo/demo.jpg")
+            bboxes, _ = await detector.apredict(img_path)
+            self.assertTrue(bboxes)