From 350fdd7aeab91cd16554a0be5bd249d2e34b222c Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Mon, 15 Jul 2019 13:38:12 +0800 Subject: [PATCH] support torchvision RoIPool and RoIAlign (#990) --- mmdet/ops/roi_align/functions/roi_align.py | 14 +++---------- mmdet/ops/roi_align/modules/roi_align.py | 24 ++++++++++++++++------ mmdet/ops/roi_pool/functions/roi_pool.py | 14 +++---------- mmdet/ops/roi_pool/modules/roi_pool.py | 16 +++++++++++---- 4 files changed, 36 insertions(+), 32 deletions(-) diff --git a/mmdet/ops/roi_align/functions/roi_align.py b/mmdet/ops/roi_align/functions/roi_align.py index 096badd..cd2ee9e 100644 --- a/mmdet/ops/roi_align/functions/roi_align.py +++ b/mmdet/ops/roi_align/functions/roi_align.py @@ -1,4 +1,5 @@ from torch.autograd import Function +from torch.nn.modules.utils import _pair from .. import roi_align_cuda @@ -7,17 +8,8 @@ class RoIAlignFunction(Function): @staticmethod def forward(ctx, features, rois, out_size, spatial_scale, sample_num=0): - if isinstance(out_size, int): - out_h = out_size - out_w = out_size - elif isinstance(out_size, tuple): - assert len(out_size) == 2 - assert isinstance(out_size[0], int) - assert isinstance(out_size[1], int) - out_h, out_w = out_size - else: - raise TypeError( - '"out_size" must be an integer or tuple of integers') + out_h, out_w = _pair(out_size) + assert isinstance(out_h, int) and isinstance(out_w, int) ctx.spatial_scale = spatial_scale ctx.sample_num = sample_num ctx.save_for_backward(rois) diff --git a/mmdet/ops/roi_align/modules/roi_align.py b/mmdet/ops/roi_align/modules/roi_align.py index b83b74e..de987bd 100644 --- a/mmdet/ops/roi_align/modules/roi_align.py +++ b/mmdet/ops/roi_align/modules/roi_align.py @@ -1,16 +1,28 @@ -from torch.nn.modules.module import Module -from ..functions.roi_align import RoIAlignFunction +import torch.nn as nn +from torch.nn.modules.utils import _pair +from ..functions.roi_align import roi_align -class RoIAlign(Module): - def __init__(self, out_size, spatial_scale, sample_num=0): +class RoIAlign(nn.Module): + + def __init__(self, + out_size, + spatial_scale, + sample_num=0, + use_torchvision=False): super(RoIAlign, self).__init__() self.out_size = out_size self.spatial_scale = float(spatial_scale) self.sample_num = int(sample_num) + self.use_torchvision = use_torchvision def forward(self, features, rois): - return RoIAlignFunction.apply(features, rois, self.out_size, - self.spatial_scale, self.sample_num) + if self.use_torchvision: + from torchvision.ops import roi_align as tv_roi_align + return tv_roi_align(features, rois, _pair(self.out_size), + self.spatial_scale, self.sample_num) + else: + return roi_align(features, rois, self.out_size, self.spatial_scale, + self.sample_num) diff --git a/mmdet/ops/roi_pool/functions/roi_pool.py b/mmdet/ops/roi_pool/functions/roi_pool.py index 068da60..6de4008 100644 --- a/mmdet/ops/roi_pool/functions/roi_pool.py +++ b/mmdet/ops/roi_pool/functions/roi_pool.py @@ -1,5 +1,6 @@ import torch from torch.autograd import Function +from torch.nn.modules.utils import _pair from .. import roi_pool_cuda @@ -8,18 +9,9 @@ class RoIPoolFunction(Function): @staticmethod def forward(ctx, features, rois, out_size, spatial_scale): - if isinstance(out_size, int): - out_h = out_size - out_w = out_size - elif isinstance(out_size, tuple): - assert len(out_size) == 2 - assert isinstance(out_size[0], int) - assert isinstance(out_size[1], int) - out_h, out_w = out_size - else: - raise TypeError( - '"out_size" must be an integer or tuple of integers') assert features.is_cuda + out_h, out_w = _pair(out_size) + assert isinstance(out_h, int) and isinstance(out_w, int) ctx.save_for_backward(rois) num_channels = features.size(1) num_rois = rois.size(0) diff --git a/mmdet/ops/roi_pool/modules/roi_pool.py b/mmdet/ops/roi_pool/modules/roi_pool.py index d7fffd0..c173cbb 100644 --- a/mmdet/ops/roi_pool/modules/roi_pool.py +++ b/mmdet/ops/roi_pool/modules/roi_pool.py @@ -1,14 +1,22 @@ -from torch.nn.modules.module import Module +import torch.nn as nn +from torch.nn.modules.utils import _pair + from ..functions.roi_pool import roi_pool -class RoIPool(Module): +class RoIPool(nn.Module): - def __init__(self, out_size, spatial_scale): + def __init__(self, out_size, spatial_scale, use_torchvision=False): super(RoIPool, self).__init__() self.out_size = out_size self.spatial_scale = float(spatial_scale) + self.use_torchvision = use_torchvision def forward(self, features, rois): - return roi_pool(features, rois, self.out_size, self.spatial_scale) + if self.use_torchvision: + from torchvision.ops import roi_pool as tv_roi_pool + return tv_roi_pool(features, rois, _pair(self.out_size), + self.spatial_scale) + else: + return roi_pool(features, rois, self.out_size, self.spatial_scale) -- GitLab