diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md
index 77eb71492b242bea660013fa54f6397201358c65..b07d90d8e42b1bf8b00db8f9ac77ab1f48b96527 100644
--- a/docs/GETTING_STARTED.md
+++ b/docs/GETTING_STARTED.md
@@ -74,6 +74,7 @@ python demo/webcam_demo.py configs/faster_rcnn_r50_fpn_1x.py \
 
 ### High-level APIs for testing images
 
+#### Synchronous interface
 Here is an example of building the model and testing given images.
 
 ```python
@@ -103,6 +104,48 @@ for frame in video:
 ```
 
 A notebook demo can be found in [demo/inference_demo.ipynb](../demo/inference_demo.ipynb).
 
+#### Asynchronous interface - supported on Python 3.7+
+
+The async interface keeps the CPU from blocking on GPU-bound inference code and enables better CPU/GPU utilization in single-threaded applications. Inference can run concurrently, either across different input samples or across different models of an inference pipeline.
+
+See `tests/async_benchmark.py` to compare the speed of the synchronous and asynchronous interfaces.
+
+```python
+import asyncio
+import torch
+from mmdet.apis import init_detector, async_inference_detector, show_result
+from mmdet.utils.contextmanagers import concurrent
+
+async def main():
+    config_file = 'configs/faster_rcnn_r50_fpn_1x.py'
+    checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth'
+    device = 'cuda:0'
+    model = init_detector(config_file, checkpoint=checkpoint_file, device=device)
+
+    # the queue is used for concurrent inference of multiple images
+    streamqueue = asyncio.Queue()
+    # queue size defines the concurrency level
+    streamqueue_size = 3
+
+    for _ in range(streamqueue_size):
+        streamqueue.put_nowait(torch.cuda.Stream(device=device))
+
+    # test a single image and show the results
+    img = 'test.jpg'  # or img = mmcv.imread(img), which will only load it once
+
+    async with concurrent(streamqueue):
+        result = await async_inference_detector(model, img)
+
+    # visualize the results in a new window
+    show_result(img, result, model.CLASSES)
+    # or save the visualization results to image files
+    show_result(img, result, model.CLASSES, out_file='result.jpg')
+
+
+asyncio.run(main())
+
+```
+
 
 ## Train a model
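The documentation added above awaits a single request; the concurrency claim is easiest to see with several requests in flight. Below is a minimal sketch (not part of the patch — the config/checkpoint paths and image names are placeholders) that fans several images out over the stream queue with `asyncio.gather`; at most `streamqueue_size` forward passes overlap on the GPU, because each request must borrow a stream from the queue.

```python
import asyncio

import mmcv
import torch

from mmdet.apis import async_inference_detector, init_detector
from mmdet.utils.contextmanagers import concurrent


async def main():
    device = 'cuda:0'
    # placeholder config/checkpoint paths
    model = init_detector('configs/faster_rcnn_r50_fpn_1x.py',
                          checkpoint='checkpoints/faster_rcnn_r50_fpn_1x.pth',
                          device=device)

    streamqueue = asyncio.Queue()
    streamqueue_size = 3  # concurrency level
    for _ in range(streamqueue_size):
        streamqueue.put_nowait(torch.cuda.Stream(device=device))

    async def detect(frame):
        # borrow a CUDA stream for the duration of this request
        async with concurrent(streamqueue):
            return await async_inference_detector(model, frame)

    frames = [mmcv.imread(f) for f in ['img1.jpg', 'img2.jpg', 'img3.jpg']]
    results = await asyncio.gather(*(detect(frame) for frame in frames))
    print(len(results))


asyncio.run(main())
```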
diff --git a/mmdet/apis/__init__.py b/mmdet/apis/__init__.py
index f85c7af686996d91faa1939e252479ac699e5d39..4cdf847b25d4f11b5d19ce28c3c35253035bbc12 100644
--- a/mmdet/apis/__init__.py
+++ b/mmdet/apis/__init__.py
@@ -1,9 +1,10 @@
 from .env import get_root_logger, init_dist, set_random_seed
-from .inference import (inference_detector, init_detector, show_result,
-                        show_result_pyplot)
+from .inference import (async_inference_detector, inference_detector,
+                        init_detector, show_result, show_result_pyplot)
 from .train import train_detector
 
 __all__ = [
-    'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
-    'init_detector', 'inference_detector', 'show_result', 'show_result_pyplot'
+    'async_inference_detector', 'init_dist', 'get_root_logger',
+    'set_random_seed', 'train_detector', 'init_detector', 'inference_detector',
+    'show_result', 'show_result_pyplot'
 ]
diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py
index ccf228a77eb42266dd5e61cc7860ca4af34d7036..6724e85c85009366cef8c070a466f8d1c2b65ff8 100644
--- a/mmdet/apis/inference.py
+++ b/mmdet/apis/inference.py
@@ -84,7 +84,34 @@ def inference_detector(model, img):
     # forward the model
     with torch.no_grad():
         result = model(return_loss=False, rescale=True, **data)
+    return result
+
+
+async def async_inference_detector(model, img):
+    """Async inference image(s) with the detector.
+
+    Args:
+        model (nn.Module): The loaded detector.
+        img (str/ndarray or list[str/ndarray]): Either image files or loaded
+            images.
+
+    Returns:
+        Awaitable detection results.
+    """
+    cfg = model.cfg
+    device = next(model.parameters()).device  # model device
+    # build the data pipeline
+    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
+    test_pipeline = Compose(test_pipeline)
+    # prepare data
+    data = dict(img=img)
+    data = test_pipeline(data)
+    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
+
+    # We don't restore `torch.is_grad_enabled()` value during concurrent
+    # inference since execution can overlap
+    torch.set_grad_enabled(False)
+    result = await model.aforward_test(rescale=True, **data)
     return result
 
 
diff --git a/mmdet/models/anchor_heads/rpn_head.py b/mmdet/models/anchor_heads/rpn_head.py
index 50f1cc515e05f0fa4e485960db2035156c9434e7..f88b949cf8b3051610697d15772bf1b7ea938a06 100644
--- a/mmdet/models/anchor_heads/rpn_head.py
+++ b/mmdet/models/anchor_heads/rpn_head.py
@@ -65,7 +65,6 @@ class RPNHead(AnchorHead):
             rpn_cls_score = cls_scores[idx]
             rpn_bbox_pred = bbox_preds[idx]
             assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
-            anchors = mlvl_anchors[idx]
             rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
             if self.use_sigmoid_cls:
                 rpn_cls_score = rpn_cls_score.reshape(-1)
@@ -74,6 +73,7 @@ class RPNHead(AnchorHead):
                 rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                 scores = rpn_cls_score.softmax(dim=1)[:, 1]
             rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+            anchors = mlvl_anchors[idx]
             if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
                 _, topk_inds = scores.topk(cfg.nms_pre)
                 rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
index 42f47dde8ed5043dd2c2ccc9ec11836ea749028c..6a33dab4c59ec8551dc3fa5e5a8965954c063d56 100644
--- a/mmdet/models/detectors/base.py
+++ b/mmdet/models/detectors/base.py
@@ -61,6 +61,10 @@ class BaseDetector(nn.Module):
         """
         pass
 
+    @abstractmethod
+    async def async_simple_test(self, img, img_meta, **kwargs):
+        pass
+
     @abstractmethod
     def simple_test(self, img, img_meta, **kwargs):
         pass
@@ -74,6 +78,26 @@ class BaseDetector(nn.Module):
             logger = logging.getLogger()
             logger.info('load model from: {}'.format(pretrained))
 
+    async def aforward_test(self, *, img, img_meta, **kwargs):
+        for var, name in [(img, 'img'), (img_meta, 'img_meta')]:
+            if not isinstance(var, list):
+                raise TypeError('{} must be a list, but got {}'.format(
+                    name, type(var)))
+
+        num_augs = len(img)
+        if num_augs != len(img_meta):
+            raise ValueError(
+                'num of augmentations ({}) != num of image meta ({})'.format(
+                    len(img), len(img_meta)))
+        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
+        imgs_per_gpu = img[0].size(0)
+        assert imgs_per_gpu == 1
+
+        if num_augs == 1:
+            return await self.async_simple_test(img[0], img_meta[0], **kwargs)
+        else:
+            raise NotImplementedError
+
     def forward_test(self, imgs, img_metas, **kwargs):
         """
         Args:
diff --git a/mmdet/models/detectors/test_mixins.py b/mmdet/models/detectors/test_mixins.py
index 05b2081271e09ed9c845c593cbe4946a5ab31e97..84a96d1679044c658cc6b7014006ef73b210054a 100644
--- a/mmdet/models/detectors/test_mixins.py
+++ b/mmdet/models/detectors/test_mixins.py
@@ -1,11 +1,33 @@
+import logging
+import sys
+
 import torch
 
 from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_bboxes,
                         merge_aug_masks, merge_aug_proposals, multiclass_nms)
 
+logger = logging.getLogger(__name__)
+
+if sys.version_info >= (3, 7):
+    from mmdet.utils.contextmanagers import completed
+
 
 class RPNTestMixin(object):
 
+    if sys.version_info >= (3, 7):
+
+        async def async_test_rpn(self, x, img_meta, rpn_test_cfg):
+            sleep_interval = rpn_test_cfg.pop("async_sleep_interval", 0.025)
+            async with completed(
+                    __name__, "rpn_head_forward",
+                    sleep_interval=sleep_interval):
+                rpn_outs = self.rpn_head(x)
+
+            proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
+
+            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
+            return proposal_list
+
     def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
         rpn_outs = self.rpn_head(x)
         proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
@@ -37,6 +59,41 @@ class RPNTestMixin(object):
 
 class BBoxTestMixin(object):
 
+    if sys.version_info >= (3, 7):
+
+        async def async_test_bboxes(self,
+                                    x,
+                                    img_meta,
+                                    proposals,
+                                    rcnn_test_cfg,
+                                    rescale=False,
+                                    bbox_semaphore=None,
+                                    global_lock=None):
+            """Async test only det bboxes without augmentation."""
+            rois = bbox2roi(proposals)
+            roi_feats = self.bbox_roi_extractor(
+                x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
+            if self.with_shared_head:
+                roi_feats = self.shared_head(roi_feats)
+            sleep_interval = rcnn_test_cfg.get("async_sleep_interval", 0.017)
+
+            async with completed(
+                    __name__, "bbox_head_forward",
+                    sleep_interval=sleep_interval):
+                cls_score, bbox_pred = self.bbox_head(roi_feats)
+
+            img_shape = img_meta[0]['img_shape']
+            scale_factor = img_meta[0]['scale_factor']
+            det_bboxes, det_labels = self.bbox_head.get_det_bboxes(
+                rois,
+                cls_score,
+                bbox_pred,
+                img_shape,
+                scale_factor,
+                rescale=rescale,
+                cfg=rcnn_test_cfg)
+            return det_bboxes, det_labels
+
     def simple_test_bboxes(self,
                            x,
                            img_meta,
@@ -102,6 +159,46 @@ class BBoxTestMixin(object):
 
 class MaskTestMixin(object):
 
+    if sys.version_info >= (3, 7):
+
+        async def async_test_mask(self,
+                                  x,
+                                  img_meta,
+                                  det_bboxes,
+                                  det_labels,
+                                  rescale=False,
+                                  mask_test_cfg=None):
+            # image shape of the first image in the batch (only one)
+            ori_shape = img_meta[0]['ori_shape']
+            scale_factor = img_meta[0]['scale_factor']
+            if det_bboxes.shape[0] == 0:
+                segm_result = [[]
+                               for _ in range(self.mask_head.num_classes - 1)]
+            else:
+                _bboxes = (
+                    det_bboxes[:, :4] *
+                    scale_factor if rescale else det_bboxes)
+                mask_rois = bbox2roi([_bboxes])
+                mask_feats = self.mask_roi_extractor(
+                    x[:len(self.mask_roi_extractor.featmap_strides)],
+                    mask_rois)
+
+                if self.with_shared_head:
+                    mask_feats = self.shared_head(mask_feats)
+                if mask_test_cfg and mask_test_cfg.get('async_sleep_interval'):
+                    sleep_interval = mask_test_cfg['async_sleep_interval']
+                else:
+                    sleep_interval = 0.035
+                async with completed(
+                        __name__,
+                        "mask_head_forward",
+                        sleep_interval=sleep_interval):
+                    mask_pred = self.mask_head(mask_feats)
+                segm_result = self.mask_head.get_seg_masks(
+                    mask_pred, _bboxes, det_labels, self.test_cfg.rcnn,
+                    ori_shape, scale_factor, rescale)
+            return segm_result
+
     def simple_test_mask(self,
                          x,
                          img_meta,
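The three async mixins above read an optional `async_sleep_interval` from their test configs (`rpn_test_cfg.pop(...)`, `rcnn_test_cfg.get(...)`, `mask_test_cfg.get(...)`), defaulting to 0.025 s, 0.017 s and 0.035 s. A hedged sketch of how a config could expose these knobs is shown below; only the `async_sleep_interval` keys come from this patch, the remaining values mirror a stock Mask R-CNN `test_cfg` and may differ in your config.

```python
# Hypothetical config excerpt: tune how often each head polls its CUDA events.
test_cfg = dict(
    rpn=dict(
        nms_across_levels=False,
        nms_pre=1000,
        nms_post=1000,
        max_num=1000,
        nms_thr=0.7,
        min_bbox_size=0,
        async_sleep_interval=0.025),  # polling interval for the RPN forward
    rcnn=dict(
        score_thr=0.05,
        nms=dict(type='nms', iou_thr=0.5),
        max_per_img=100,
        mask_thr_binary=0.5,
        async_sleep_interval=0.017),  # polling interval for the bbox head
    mask=dict(async_sleep_interval=0.035))  # read via test_cfg.get('mask')
```

Note that `async_test_rpn` uses `pop`, so the RPN key is removed from the config dict the first time it is read and later calls fall back to the default.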
diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py
index 6a75eb3640f4ee552ca84b107f7ed3b56fb77f54..155819518a2879eeb1fadf57c177f74deb22c48d 100644
--- a/mmdet/models/detectors/two_stage.py
+++ b/mmdet/models/detectors/two_stage.py
@@ -260,14 +260,49 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
 
         return losses
 
+    async def async_simple_test(self,
+                                img,
+                                img_meta,
+                                proposals=None,
+                                rescale=False):
+        """Async test without augmentation."""
+        assert self.with_bbox, "Bbox head must be implemented."
+        x = self.extract_feat(img)
+
+        if proposals is None:
+            proposal_list = await self.async_test_rpn(x, img_meta,
+                                                      self.test_cfg.rpn)
+        else:
+            proposal_list = proposals
+
+        det_bboxes, det_labels = await self.async_test_bboxes(
+            x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
+        bbox_results = bbox2result(det_bboxes, det_labels,
+                                   self.bbox_head.num_classes)
+
+        if not self.with_mask:
+            return bbox_results
+        else:
+            segm_results = await self.async_test_mask(
+                x,
+                img_meta,
+                det_bboxes,
+                det_labels,
+                rescale=rescale,
+                mask_test_cfg=self.test_cfg.get('mask'))
+            return bbox_results, segm_results
+
     def simple_test(self, img, img_meta, proposals=None, rescale=False):
         """Test without augmentation."""
         assert self.with_bbox, "Bbox head must be implemented."
 
         x = self.extract_feat(img)
 
-        proposal_list = self.simple_test_rpn(
-            x, img_meta, self.test_cfg.rpn) if proposals is None else proposals
+        if proposals is None:
+            proposal_list = self.simple_test_rpn(x, img_meta,
+                                                 self.test_cfg.rpn)
+        else:
+            proposal_list = proposals
 
         det_bboxes, det_labels = self.simple_test_bboxes(
             x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
diff --git a/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu b/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
index 402b4499cae114f2d2b30f01f182a5198c910298..e7a26f2e830846f80272bcd8c5ce0def34593c95 100644
--- a/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+++ b/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
@@ -61,6 +61,7 @@
 // modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
 
 #include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
 #include <THC/THCAtomics.cuh>
 #include <stdio.h>
 #include <math.h>
@@ -261,7 +262,7 @@ void deformable_im2col(
         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
         scalar_t *data_col_ = data_col.data<scalar_t>();
 
-        deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+        deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
             num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
             pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
             channel_per_deformable_group, parallel_imgs, channels, deformable_group,
@@ -355,7 +356,7 @@ void deformable_col2im(
         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
         scalar_t *grad_im_ = grad_im.data<scalar_t>();
 
-        deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+        deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
             num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
             ksize_w, pad_h, pad_w, stride_h, stride_w,
             dilation_h, dilation_w, channel_per_deformable_group,
@@ -454,7 +455,7 @@ void deformable_col2im_coord(
         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
         scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
 
-        deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+        deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
             num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
             ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
             dilation_h, dilation_w, channel_per_deformable_group,
@@ -784,7 +785,7 @@ void modulated_deformable_im2col_cuda(
         const scalar_t *data_mask_ = data_mask.data<scalar_t>();
         scalar_t *data_col_ = data_col.data<scalar_t>();
 
-        modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+        modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
             num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im,
             kernel_h, kenerl_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
             dilation_w, channel_per_deformable_group, batch_size, channels,
             deformable_group, height_col, width_col, data_col_);
@@ -816,7 +817,7 @@ void modulated_deformable_col2im_cuda(
         const scalar_t *data_mask_ = data_mask.data<scalar_t>();
         scalar_t *grad_im_ = grad_im.data<scalar_t>();
 
-        modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+        modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
             num_kernels, data_col_, data_offset_, data_mask_, channels, height_im,
             width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
             dilation_h, dilation_w, channel_per_deformable_group,
@@ -851,7 +852,7 @@ void modulated_deformable_col2im_coord_cuda(
         scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
         scalar_t *grad_mask_ = grad_mask.data<scalar_t>();
 
-        modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+        modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
             num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels,
             height_im, width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h,
             stride_w, dilation_h, dilation_w, channel_per_deformable_group,
diff --git a/mmdet/ops/dcn/src/deform_pool_cuda_kernel.cu b/mmdet/ops/dcn/src/deform_pool_cuda_kernel.cu
index 1922d72442253353b29200a00a204b34a9153fa4..05b00d4be618353b404540469bf6118902651ca2 100644
--- a/mmdet/ops/dcn/src/deform_pool_cuda_kernel.cu
+++ b/mmdet/ops/dcn/src/deform_pool_cuda_kernel.cu
@@ -296,7 +296,7 @@ void DeformablePSROIPoolForward(const at::Tensor data,
         scalar_t *top_data = out.data<scalar_t>();
         scalar_t *top_count_data = top_count.data<scalar_t>();
 
-        DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
+        DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
           count, bottom_data, (scalar_t)spatial_scale, channels, height, width,
           pooled_height, pooled_width, bottom_rois, bottom_trans, no_trans,
           (scalar_t)trans_std, sample_per_part, output_dim, group_size,
          part_size, num_classes, channels_each_class, top_data, top_count_data);
@@ -349,7 +349,7 @@ void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
         scalar_t *bottom_trans_diff = no_trans ? NULL : trans_grad.data<scalar_t>();
         const scalar_t *top_count_data = top_count.data<scalar_t>();
 
-        DeformablePSROIPoolBackwardAccKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
+        DeformablePSROIPoolBackwardAccKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
           count, top_diff, top_count_data, num_rois, (scalar_t)spatial_scale,
           channels, height, width, pooled_height, pooled_width, output_dim,
           bottom_data_diff, bottom_trans_diff, bottom_data, bottom_rois,
           bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part,
@@ -361,4 +361,4 @@ void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
   {
     printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
   }
-}
\ No newline at end of file
+}
diff --git a/mmdet/ops/masked_conv/src/masked_conv2d_kernel.cu b/mmdet/ops/masked_conv/src/masked_conv2d_kernel.cu
index 2312d1200c926d7c3ecc5fc7655b344f63de509e..81c785bbe41461fa8a4d380dbbef60dbe677cf6a 100644
--- a/mmdet/ops/masked_conv/src/masked_conv2d_kernel.cu
+++ b/mmdet/ops/masked_conv/src/masked_conv2d_kernel.cu
@@ -1,4 +1,5 @@
 #include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
 #include <THC/THCAtomics.cuh>
 
 #define CUDA_1D_KERNEL_LOOP(i, n) \
@@ -63,7 +64,8 @@ int MaskedIm2colForwardLaucher(const at::Tensor bottom_data, const int height,
         const int64_t *mask_w_idx_ = mask_w_idx.data<int64_t>();
         scalar_t *top_data_ = top_data.data<scalar_t>();
         MaskedIm2colForward<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0,
+               at::cuda::getCurrentCUDAStream()>>>(
                 output_size, bottom_data_, height, width, kernel_h, kernel_w,
                 pad_h, pad_w, mask_h_idx_, mask_w_idx_, mask_cnt, top_data_);
       }));
@@ -103,7 +105,7 @@ int MaskedCol2imForwardLaucher(const at::Tensor bottom_data, const int height,
         scalar_t *top_data_ = top_data.data<scalar_t>();
 
         MaskedCol2imForward<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
                 output_size, bottom_data_, height, width, channels, mask_h_idx_,
                 mask_w_idx_, mask_cnt, top_data_);
       }));
diff --git a/mmdet/ops/nms/src/nms_kernel.cu b/mmdet/ops/nms/src/nms_kernel.cu
index c0811cfcb4802355f5784d1e737374a4b08ebb55..ada9bea25237046b4001fea0e0c7061c6a53886f 100644
--- a/mmdet/ops/nms/src/nms_kernel.cu
+++ b/mmdet/ops/nms/src/nms_kernel.cu
@@ -96,16 +96,19 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
   dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
               THCCeilDiv(boxes_num, threadsPerBlock));
   dim3 threads(threadsPerBlock);
-  nms_kernel<<<blocks, threads>>>(boxes_num,
+  nms_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(boxes_num,
                                   nms_overlap_thresh,
                                   boxes_dev,
                                   mask_dev);
 
   std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
-  THCudaCheck(cudaMemcpy(&mask_host[0],
-                        mask_dev,
-                        sizeof(unsigned long long) * boxes_num * col_blocks,
-                        cudaMemcpyDeviceToHost));
+  THCudaCheck(cudaMemcpyAsync(
+      &mask_host[0],
+      mask_dev,
+      sizeof(unsigned long long) * boxes_num * col_blocks,
+      cudaMemcpyDeviceToHost,
+      at::cuda::getCurrentCUDAStream()
+  ));
 
   std::vector<unsigned long long> remv(col_blocks);
   memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
diff --git a/mmdet/ops/roi_align/src/roi_align_cuda.cpp b/mmdet/ops/roi_align/src/roi_align_cuda.cpp
index 06a73aa05e49b547cc734af3e223036edb29f3b2..12c16f009d7dc9e8b0298beb5385e8cf0256becc 100644
--- a/mmdet/ops/roi_align/src/roi_align_cuda.cpp
+++ b/mmdet/ops/roi_align/src/roi_align_cuda.cpp
@@ -1,5 +1,7 @@
 #include <torch/extension.h>
+#include <ATen/ATen.h>
+
 #include <cmath>
 #include <vector>
 
diff --git a/mmdet/ops/roi_align/src/roi_align_kernel.cu b/mmdet/ops/roi_align/src/roi_align_kernel.cu
index eb7cdaf1ffa5bc1ff25f1c873d0f23e9590b0912..3208b2806155cbdd46e1c9230ad0c4db03d33f51 100644
--- a/mmdet/ops/roi_align/src/roi_align_kernel.cu
+++ b/mmdet/ops/roi_align/src/roi_align_kernel.cu
@@ -1,4 +1,5 @@
 #include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
 #include <THC/THCAtomics.cuh>
 
 #define CUDA_1D_KERNEL_LOOP(i, n) \
@@ -131,7 +132,7 @@ int ROIAlignForwardLaucher(const at::Tensor features, const at::Tensor rois,
         scalar_t *top_data = output.data<scalar_t>();
 
         ROIAlignForward<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
                 output_size, bottom_data, rois_data, scalar_t(spatial_scale),
                 sample_num, channels, height, width, pooled_height,
                 pooled_width, top_data);
@@ -272,7 +273,7 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
         }
 
         ROIAlignBackward<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
                 output_size, top_diff, rois_data, spatial_scale, sample_num,
                 channels, height, width, pooled_height, pooled_width,
                 bottom_diff);
diff --git a/mmdet/ops/roi_pool/src/roi_pool_kernel.cu b/mmdet/ops/roi_pool/src/roi_pool_kernel.cu
index 25ba98532850422f3d54fc3b00c35ebdf1dd8ae3..37c3d0c4926696bbf3ca7cbad37ac4991ef80055 100644
--- a/mmdet/ops/roi_pool/src/roi_pool_kernel.cu
+++ b/mmdet/ops/roi_pool/src/roi_pool_kernel.cu
@@ -1,4 +1,5 @@
 #include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
 #include <THC/THCAtomics.cuh>
 
 #define CUDA_1D_KERNEL_LOOP(i, n) \
@@ -93,7 +94,7 @@ int ROIPoolForwardLaucher(const at::Tensor features, const at::Tensor rois,
         int *argmax_data = argmax.data<int>();
 
         ROIPoolForward<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
                 output_size, bottom_data, rois_data, scalar_t(spatial_scale),
                 channels, height, width, pooled_h, pooled_w, top_data,
                 argmax_data);
@@ -146,7 +147,7 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
         }
 
         ROIPoolBackward<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>(
                 output_size, top_diff, rois_data, argmax_data,
                 scalar_t(spatial_scale), channels, height, width, pooled_h,
                 pooled_w, bottom_diff);
diff --git a/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu b/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu
index 6a9104c1cd9ce1e817e1b75791fd822f819d9a85..0e152d38fd15858c4b8c51936aa68f78796a319a 100644
--- a/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu
+++ b/mmdet/ops/sigmoid_focal_loss/src/sigmoid_focal_loss_cuda.cu
@@ -120,7 +120,7 @@ at::Tensor SigmoidFocalLoss_forward_cuda(const at::Tensor &logits,
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       logits.scalar_type(), "SigmoidFocalLoss_forward", [&] {
-        SigmoidFocalLossForward<scalar_t><<<grid, block>>>(
+        SigmoidFocalLossForward<scalar_t><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
             losses_size, logits.contiguous().data<scalar_t>(),
             targets.contiguous().data<int64_t>(), num_classes, gamma, alpha,
             num_samples, losses.data<scalar_t>());
@@ -159,7 +159,7 @@ at::Tensor SigmoidFocalLoss_backward_cuda(const at::Tensor &logits,
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       logits.scalar_type(), "SigmoidFocalLoss_backward", [&] {
-        SigmoidFocalLossBackward<scalar_t><<<grid, block>>>(
+        SigmoidFocalLossBackward<scalar_t><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
            d_logits_size, logits.contiguous().data<scalar_t>(),
            targets.contiguous().data<int64_t>(),
            d_losses.contiguous().data<scalar_t>(), num_classes, gamma, alpha,
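The kernel-launch changes above all follow one pattern: launch on `at::cuda::getCurrentCUDAStream()` instead of the legacy default stream, so the custom ops land on whatever stream `torch.cuda.stream(...)` (used by `concurrent`) has made current. The sketch below is a small sanity check of that mechanism from the Python side; it is not part of the patch and assumes a CUDA-capable device.

```python
import torch

# Work issued inside torch.cuda.stream(...) is queued on that stream; custom
# extensions must query the current stream to end up on the same queue.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    side_stream = torch.cuda.Stream(device=device)
    conv = torch.nn.Conv2d(3, 16, 3).to(device)
    x = torch.randn(4, 3, 224, 224, device=device)

    with torch.cuda.stream(side_stream):
        assert torch.cuda.current_stream(device) == side_stream
        y = conv(x)                     # queued on side_stream
        done = torch.cuda.Event()
        side_stream.record_event(done)  # completion marker for the work above

    while not done.query():  # poll instead of blocking the host, which is
        pass                 # the same idea completed() implements with await
    print(y.shape)
```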
diff --git a/mmdet/utils/contextmanagers.py b/mmdet/utils/contextmanagers.py
new file mode 100644
index 0000000000000000000000000000000000000000..12073bef93219fedd96713aa7c6452c1530679a5
--- /dev/null
+++ b/mmdet/utils/contextmanagers.py
@@ -0,0 +1,122 @@
+# coding: utf-8
+import asyncio
+import contextlib
+import logging
+import os
+import time
+from typing import List
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+DEBUG_COMPLETED_TIME = bool(os.environ.get('DEBUG_COMPLETED_TIME', False))
+
+
+@contextlib.asynccontextmanager
+async def completed(trace_name="",
+                    name="",
+                    sleep_interval=0.05,
+                    streams: List[torch.cuda.Stream] = None):
+    """
+    Async context manager that waits for work to complete on
+    given CUDA streams.
+
+    """
+    if not torch.cuda.is_available():
+        yield
+        return
+
+    stream_before_context_switch = torch.cuda.current_stream()
+    if not streams:
+        streams = [stream_before_context_switch]
+    else:
+        streams = [s if s else stream_before_context_switch for s in streams]
+
+    end_events = [
+        torch.cuda.Event(enable_timing=DEBUG_COMPLETED_TIME) for _ in streams
+    ]
+
+    if DEBUG_COMPLETED_TIME:
+        start = torch.cuda.Event(enable_timing=True)
+        stream_before_context_switch.record_event(start)
+
+        cpu_start = time.monotonic()
+    logger.debug("%s %s starting, streams: %s", trace_name, name, streams)
+    grad_enabled_before = torch.is_grad_enabled()
+    try:
+        yield
+    finally:
+        current_stream = torch.cuda.current_stream()
+        assert current_stream == stream_before_context_switch
+
+        if DEBUG_COMPLETED_TIME:
+            cpu_end = time.monotonic()
+        for i, stream in enumerate(streams):
+            event = end_events[i]
+            stream.record_event(event)
+
+        grad_enabled_after = torch.is_grad_enabled()
+
+        # observed change of torch.is_grad_enabled() during concurrent run of
+        # async_test_bboxes code
+        assert grad_enabled_before == grad_enabled_after, \
+            "Unexpected is_grad_enabled() value change"
+
+        are_done = [e.query() for e in end_events]
+        logger.debug("%s %s completed: %s streams: %s", trace_name, name,
+                     are_done, streams)
+        with torch.cuda.stream(stream_before_context_switch):
+            while not all(are_done):
+                await asyncio.sleep(sleep_interval)
+                are_done = [e.query() for e in end_events]
+                logger.debug("%s %s completed: %s streams: %s", trace_name,
+                             name, are_done, streams)
+
+        current_stream = torch.cuda.current_stream()
+        assert current_stream == stream_before_context_switch
+
+        if DEBUG_COMPLETED_TIME:
+            cpu_time = (cpu_end - cpu_start) * 1000
+            stream_times_ms = ""
+            for i, stream in enumerate(streams):
+                elapsed_time = start.elapsed_time(end_events[i])
+                stream_times_ms += " {} {:.2f} ms".format(stream, elapsed_time)
+            logger.info("%s %s cpu_time %.2f ms%s", trace_name, name,
+                        cpu_time, stream_times_ms)
+
+
+@contextlib.asynccontextmanager
+async def concurrent(streamqueue: asyncio.Queue,
+                     trace_name="concurrent",
+                     name="stream"):
+    """Run code concurrently in different streams.
+
+    :param streamqueue: asyncio.Queue instance.
+
+    Queue tasks define the pool of streams used for concurrent execution.
+
+    """
+    if not torch.cuda.is_available():
+        yield
+        return
+
+    initial_stream = torch.cuda.current_stream()
+
+    with torch.cuda.stream(initial_stream):
+        stream = await streamqueue.get()
+        assert isinstance(stream, torch.cuda.Stream)
+
+        try:
+            with torch.cuda.stream(stream):
+                logger.debug("%s %s is starting, stream: %s", trace_name,
+                             name, stream)
+                yield
+                current = torch.cuda.current_stream()
+                assert current == stream
+                logger.debug("%s %s has finished, stream: %s", trace_name,
+                             name, stream)
+        finally:
+            streamqueue.task_done()
+            streamqueue.put_nowait(stream)
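A minimal usage sketch (not part of the patch) of how the two context managers compose: `concurrent` makes a stream borrowed from the queue current for the duration of the block, and `completed` records CUDA events on it and polls them with `asyncio.sleep`, so the event loop stays free while the GPU works. The `Conv2d` below is only a stand-in for a CUDA module; a CUDA device is assumed.

```python
import asyncio

import torch

from mmdet.utils.contextmanagers import completed, concurrent


async def run_once(model, batch, streamqueue):
    # run the forward pass on a stream borrowed from the queue
    async with concurrent(streamqueue):
        # poll for completion instead of synchronizing the device
        async with completed('example', 'forward', sleep_interval=0.01):
            out = model(batch)
    return out


async def main():
    device = 'cuda:0'
    model = torch.nn.Conv2d(3, 8, 3).to(device)  # stand-in for a detector
    batch = torch.randn(1, 3, 224, 224, device=device)

    streamqueue = asyncio.Queue()
    for _ in range(2):
        streamqueue.put_nowait(torch.cuda.Stream(device=device))

    outs = await asyncio.gather(
        *(run_once(model, batch, streamqueue) for _ in range(4)))
    print([tuple(o.shape) for o in outs])


if torch.cuda.is_available():
    asyncio.run(main())
```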
diff --git a/mmdet/utils/profiling.py b/mmdet/utils/profiling.py
new file mode 100644
index 0000000000000000000000000000000000000000..58b1c87dd3f7a427bfa59b003b58e29158666d36
--- /dev/null
+++ b/mmdet/utils/profiling.py
@@ -0,0 +1,41 @@
+import contextlib
+import sys
+import time
+
+import torch
+
+if sys.version_info >= (3, 7):
+
+    @contextlib.contextmanager
+    def profile_time(trace_name,
+                     name,
+                     enabled=True,
+                     stream=None,
+                     end_stream=None):
+        """Print time spent by CPU and GPU.
+
+        Useful as a temporary context manager to find sweet spots of
+        code suitable for async implementation.
+
+        """
+        if (not enabled) or not torch.cuda.is_available():
+            yield
+            return
+        stream = stream if stream else torch.cuda.current_stream()
+        end_stream = end_stream if end_stream else stream
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        stream.record_event(start)
+        try:
+            cpu_start = time.monotonic()
+            yield
+        finally:
+            cpu_end = time.monotonic()
+            end_stream.record_event(end)
+            end.synchronize()
+            cpu_time = (cpu_end - cpu_start) * 1000
+            gpu_time = start.elapsed_time(end)
+            msg = "{} {} cpu_time {:.2f} ms ".format(trace_name, name,
+                                                     cpu_time)
+            msg += "gpu_time {:.2f} ms stream {}".format(gpu_time, stream)
+            print(msg, end_stream)
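`profile_time` is a plain synchronous context manager, so it can wrap any block. A short sketch of its use (not part of the patch; it assumes Python 3.7+ and a CUDA device), in the same spirit as the `benchmark` labels used by `tests/async_benchmark.py`:

```python
import torch

from mmdet.utils.profiling import profile_time

if torch.cuda.is_available():
    device = 'cuda:0'
    conv = torch.nn.Conv2d(3, 64, 3).to(device)  # stand-in workload
    x = torch.randn(8, 3, 224, 224, device=device)

    # prints e.g. "example conv cpu_time 1.23 ms gpu_time 45.67 ms stream ..."
    with profile_time('example', 'conv'):
        for _ in range(10):
            y = conv(x)
```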
diff --git a/setup.py b/setup.py
index 43a2ec219241d62854c027de1345fcddf6dfc02b..54d58777268ae33f0da7cdc7135f9f4969e5a1d1 100644
--- a/setup.py
+++ b/setup.py
@@ -162,7 +162,7 @@ if __name__ == '__main__':
         ],
         license='Apache License 2.0',
         setup_requires=['pytest-runner', 'cython', 'numpy'],
-        tests_require=['pytest', 'xdoctest'],
+        tests_require=['pytest', 'xdoctest', 'asynctest'],
         install_requires=get_requirements(),
         ext_modules=[
             make_cuda_ext(
diff --git a/tests/async_benchmark.py b/tests/async_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..0017783d354d24a0c9f38953a7566f7545a41bb6
--- /dev/null
+++ b/tests/async_benchmark.py
@@ -0,0 +1,104 @@
+# coding: utf-8
+
+import asyncio
+import os
+import shutil
+import urllib.request
+
+import mmcv
+import torch
+
+from mmdet.apis import (async_inference_detector, inference_detector,
+                        init_detector, show_result)
+from mmdet.utils.contextmanagers import concurrent
+from mmdet.utils.profiling import profile_time
+
+
+async def main():
+    """Benchmark between the async and synchronous inference interfaces.
+
+    Sample runs for 20 demo images on a K80 GPU, model - mask_rcnn_r50_fpn_1x:
+
+    async       sync
+
+    7981.79 ms  9660.82 ms
+    8074.52 ms  9660.94 ms
+    7976.44 ms  9406.83 ms
+
+    The async variant takes about 0.83-0.85 of the time of the synchronous
+    interface.
+    """
+    project_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+
+    config_file = os.path.join(project_dir, 'configs/mask_rcnn_r50_fpn_1x.py')
+    checkpoint_file = os.path.join(
+        project_dir, 'checkpoints/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth')
+
+    if not os.path.exists(checkpoint_file):
+        url = ('https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection'
+               '/models/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth')
+        print('Downloading {} ...'.format(url))
+        local_filename, _ = urllib.request.urlretrieve(url)
+        os.makedirs(os.path.dirname(checkpoint_file), exist_ok=True)
+        shutil.move(local_filename, checkpoint_file)
+        print('Saved as {}'.format(checkpoint_file))
+    else:
+        print('Using existing checkpoint {}'.format(checkpoint_file))
+
+    device = 'cuda:0'
+    model = init_detector(
+        config_file, checkpoint=checkpoint_file, device=device)
+
+    # queue is used for concurrent inference of multiple images
+    streamqueue = asyncio.Queue()
+    # queue size defines concurrency level
+    streamqueue_size = 4
+
+    for _ in range(streamqueue_size):
+        streamqueue.put_nowait(torch.cuda.Stream(device=device))
+
+    # test a single image and show the results
+    img = mmcv.imread(os.path.join(project_dir, 'demo/demo.jpg'))
+
+    # warmup
+    await async_inference_detector(model, img)
+
+    async def detect(img):
+        async with concurrent(streamqueue):
+            return await async_inference_detector(model, img)
+
+    num_of_images = 20
+    with profile_time('benchmark', 'async'):
+        tasks = [
+            asyncio.create_task(detect(img)) for _ in range(num_of_images)
+        ]
+        async_results = await asyncio.gather(*tasks)
+
+    with torch.cuda.stream(torch.cuda.default_stream()):
+        with profile_time('benchmark', 'sync'):
+            sync_results = [
+                inference_detector(model, img) for _ in range(num_of_images)
+            ]
+
+    result_dir = os.path.join(project_dir, 'demo')
+    show_result(
+        img,
+        async_results[0],
+        model.CLASSES,
+        score_thr=0.5,
+        show=False,
+        out_file=os.path.join(result_dir, 'result_async.jpg'))
+    show_result(
+        img,
+        sync_results[0],
+        model.CLASSES,
+        score_thr=0.5,
+        show=False,
+        out_file=os.path.join(result_dir, 'result_sync.jpg'))
+
+
+if __name__ == '__main__':
+    asyncio.run(main())
diff --git a/tests/requirements.txt b/tests/requirements.txt
index a1f3efbbe93d05ef4e1eceeb8d5640b56daf1d3e..ff60968eef8a22c8f8fa98c090b5ce9702ecd6b4 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -4,3 +4,4 @@ yapf
 pytest-cov
 codecov
 xdoctest >= 0.10.0
+asynctest
diff --git a/tests/test_async.py b/tests/test_async.py
new file mode 100644
index 0000000000000000000000000000000000000000..68ecde33d0b5d937041060275b7c23dbbc471a16
--- /dev/null
+++ b/tests/test_async.py
@@ -0,0 +1,78 @@
+"""Tests for async interface."""
+
+import asyncio
+import os
+import sys
+
+import asynctest
+import mmcv
+import torch
+
+from mmdet.apis import async_inference_detector, init_detector
+
+if sys.version_info >= (3, 7):
+    from mmdet.utils.contextmanagers import concurrent
+
+
+class AsyncTestCase(asynctest.TestCase):
+    use_default_loop = False
+    forbid_get_event_loop = True
+
+    TEST_TIMEOUT = int(os.getenv("ASYNCIO_TEST_TIMEOUT", "30"))
+
+    def _run_test_method(self, method):
+        result = method()
+        if asyncio.iscoroutine(result):
+            self.loop.run_until_complete(
+                asyncio.wait_for(result, timeout=self.TEST_TIMEOUT))
+
+
+class MaskRCNNDetector:
+
+    def __init__(self,
+                 model_config,
+                 checkpoint=None,
+                 streamqueue_size=3,
+                 device="cuda:0"):
+
+        self.streamqueue_size = streamqueue_size
+        self.device = device
+        # build the model and load checkpoint
+        self.model = init_detector(
+            model_config, checkpoint=None, device=self.device)
+        self.streamqueue = None
+
+    async def init(self):
+        self.streamqueue = asyncio.Queue()
+        for _ in range(self.streamqueue_size):
+            stream = torch.cuda.Stream(device=self.device)
+            self.streamqueue.put_nowait(stream)
+
+    if sys.version_info >= (3, 7):
+
+        async def apredict(self, img):
+            if isinstance(img, str):
+                img = mmcv.imread(img)
+            async with concurrent(self.streamqueue):
+                result = await async_inference_detector(self.model, img)
+            return result
+
+
+class AsyncInferenceTestCase(AsyncTestCase):
+
+    if sys.version_info >= (3, 7):
+
+        async def test_simple_inference(self):
+            if not torch.cuda.is_available():
+                import pytest
+
+                pytest.skip("test requires GPU and torch+cuda")
+
+            root_dir = os.path.dirname(os.path.dirname(__file__))
+            model_config = os.path.join(root_dir,
+                                        "configs/mask_rcnn_r50_fpn_1x.py")
+            detector = MaskRCNNDetector(model_config)
+            await detector.init()
+            img_path = os.path.join(root_dir, "demo/demo.jpg")
+            bboxes, _ = await detector.apredict(img_path)
+            self.assertTrue(bboxes)