diff --git a/mmdet/ops/roi_align/functions/roi_align.py b/mmdet/ops/roi_align/functions/roi_align.py index 096badd25673b6b46b3ccb36952e021d00cab835..cd2ee9edd10896f2b1728b28b1d47d4883176980 100644 --- a/mmdet/ops/roi_align/functions/roi_align.py +++ b/mmdet/ops/roi_align/functions/roi_align.py @@ -1,4 +1,5 @@ from torch.autograd import Function +from torch.nn.modules.utils import _pair from .. import roi_align_cuda @@ -7,17 +8,8 @@ class RoIAlignFunction(Function): @staticmethod def forward(ctx, features, rois, out_size, spatial_scale, sample_num=0): - if isinstance(out_size, int): - out_h = out_size - out_w = out_size - elif isinstance(out_size, tuple): - assert len(out_size) == 2 - assert isinstance(out_size[0], int) - assert isinstance(out_size[1], int) - out_h, out_w = out_size - else: - raise TypeError( - '"out_size" must be an integer or tuple of integers') + out_h, out_w = _pair(out_size) + assert isinstance(out_h, int) and isinstance(out_w, int) ctx.spatial_scale = spatial_scale ctx.sample_num = sample_num ctx.save_for_backward(rois) diff --git a/mmdet/ops/roi_align/modules/roi_align.py b/mmdet/ops/roi_align/modules/roi_align.py index b83b74e6b7c151eaf627c2b6d3530823ce8cda05..de987bd456c88a093632a96b0fcc57b2a3190e87 100644 --- a/mmdet/ops/roi_align/modules/roi_align.py +++ b/mmdet/ops/roi_align/modules/roi_align.py @@ -1,16 +1,28 @@ -from torch.nn.modules.module import Module -from ..functions.roi_align import RoIAlignFunction +import torch.nn as nn +from torch.nn.modules.utils import _pair +from ..functions.roi_align import roi_align -class RoIAlign(Module): - def __init__(self, out_size, spatial_scale, sample_num=0): +class RoIAlign(nn.Module): + + def __init__(self, + out_size, + spatial_scale, + sample_num=0, + use_torchvision=False): super(RoIAlign, self).__init__() self.out_size = out_size self.spatial_scale = float(spatial_scale) self.sample_num = int(sample_num) + self.use_torchvision = use_torchvision def forward(self, features, rois): - return RoIAlignFunction.apply(features, rois, self.out_size, - self.spatial_scale, self.sample_num) + if self.use_torchvision: + from torchvision.ops import roi_align as tv_roi_align + return tv_roi_align(features, rois, _pair(self.out_size), + self.spatial_scale, self.sample_num) + else: + return roi_align(features, rois, self.out_size, self.spatial_scale, + self.sample_num) diff --git a/mmdet/ops/roi_pool/functions/roi_pool.py b/mmdet/ops/roi_pool/functions/roi_pool.py index 068da600e5828d88ef1477c1afe19b81ee363ee2..6de40088c62828f917937a12d9ed2708ce2b85c3 100644 --- a/mmdet/ops/roi_pool/functions/roi_pool.py +++ b/mmdet/ops/roi_pool/functions/roi_pool.py @@ -1,5 +1,6 @@ import torch from torch.autograd import Function +from torch.nn.modules.utils import _pair from .. import roi_pool_cuda @@ -8,18 +9,9 @@ class RoIPoolFunction(Function): @staticmethod def forward(ctx, features, rois, out_size, spatial_scale): - if isinstance(out_size, int): - out_h = out_size - out_w = out_size - elif isinstance(out_size, tuple): - assert len(out_size) == 2 - assert isinstance(out_size[0], int) - assert isinstance(out_size[1], int) - out_h, out_w = out_size - else: - raise TypeError( - '"out_size" must be an integer or tuple of integers') assert features.is_cuda + out_h, out_w = _pair(out_size) + assert isinstance(out_h, int) and isinstance(out_w, int) ctx.save_for_backward(rois) num_channels = features.size(1) num_rois = rois.size(0) diff --git a/mmdet/ops/roi_pool/modules/roi_pool.py b/mmdet/ops/roi_pool/modules/roi_pool.py index d7fffd08c656ee7301aeed5a8262714f4be4157d..c173cbbfd7e9c4f8a7f5cdedc4258fa7e2ccbad2 100644 --- a/mmdet/ops/roi_pool/modules/roi_pool.py +++ b/mmdet/ops/roi_pool/modules/roi_pool.py @@ -1,14 +1,22 @@ -from torch.nn.modules.module import Module +import torch.nn as nn +from torch.nn.modules.utils import _pair + from ..functions.roi_pool import roi_pool -class RoIPool(Module): +class RoIPool(nn.Module): - def __init__(self, out_size, spatial_scale): + def __init__(self, out_size, spatial_scale, use_torchvision=False): super(RoIPool, self).__init__() self.out_size = out_size self.spatial_scale = float(spatial_scale) + self.use_torchvision = use_torchvision def forward(self, features, rois): - return roi_pool(features, rois, self.out_size, self.spatial_scale) + if self.use_torchvision: + from torchvision.ops import roi_pool as tv_roi_pool + return tv_roi_pool(features, rois, _pair(self.out_size), + self.spatial_scale) + else: + return roi_pool(features, rois, self.out_size, self.spatial_scale)