Skip to content
Snippets Groups Projects
Unverified Commit 350fdd7a authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

support torchvision RoIPool and RoIAlign (#990)

parent c101398c
No related branches found
No related tags found
No related merge requests found
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from .. import roi_align_cuda
......@@ -7,17 +8,8 @@ class RoIAlignFunction(Function):
@staticmethod
def forward(ctx, features, rois, out_size, spatial_scale, sample_num=0):
if isinstance(out_size, int):
out_h = out_size
out_w = out_size
elif isinstance(out_size, tuple):
assert len(out_size) == 2
assert isinstance(out_size[0], int)
assert isinstance(out_size[1], int)
out_h, out_w = out_size
else:
raise TypeError(
'"out_size" must be an integer or tuple of integers')
out_h, out_w = _pair(out_size)
assert isinstance(out_h, int) and isinstance(out_w, int)
ctx.spatial_scale = spatial_scale
ctx.sample_num = sample_num
ctx.save_for_backward(rois)
......
from torch.nn.modules.module import Module
from ..functions.roi_align import RoIAlignFunction
import torch.nn as nn
from torch.nn.modules.utils import _pair
from ..functions.roi_align import roi_align
class RoIAlign(Module):
def __init__(self, out_size, spatial_scale, sample_num=0):
class RoIAlign(nn.Module):
def __init__(self,
out_size,
spatial_scale,
sample_num=0,
use_torchvision=False):
super(RoIAlign, self).__init__()
self.out_size = out_size
self.spatial_scale = float(spatial_scale)
self.sample_num = int(sample_num)
self.use_torchvision = use_torchvision
def forward(self, features, rois):
return RoIAlignFunction.apply(features, rois, self.out_size,
self.spatial_scale, self.sample_num)
if self.use_torchvision:
from torchvision.ops import roi_align as tv_roi_align
return tv_roi_align(features, rois, _pair(self.out_size),
self.spatial_scale, self.sample_num)
else:
return roi_align(features, rois, self.out_size, self.spatial_scale,
self.sample_num)
import torch
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from .. import roi_pool_cuda
......@@ -8,18 +9,9 @@ class RoIPoolFunction(Function):
@staticmethod
def forward(ctx, features, rois, out_size, spatial_scale):
if isinstance(out_size, int):
out_h = out_size
out_w = out_size
elif isinstance(out_size, tuple):
assert len(out_size) == 2
assert isinstance(out_size[0], int)
assert isinstance(out_size[1], int)
out_h, out_w = out_size
else:
raise TypeError(
'"out_size" must be an integer or tuple of integers')
assert features.is_cuda
out_h, out_w = _pair(out_size)
assert isinstance(out_h, int) and isinstance(out_w, int)
ctx.save_for_backward(rois)
num_channels = features.size(1)
num_rois = rois.size(0)
......
from torch.nn.modules.module import Module
import torch.nn as nn
from torch.nn.modules.utils import _pair
from ..functions.roi_pool import roi_pool
class RoIPool(Module):
class RoIPool(nn.Module):
def __init__(self, out_size, spatial_scale):
def __init__(self, out_size, spatial_scale, use_torchvision=False):
super(RoIPool, self).__init__()
self.out_size = out_size
self.spatial_scale = float(spatial_scale)
self.use_torchvision = use_torchvision
def forward(self, features, rois):
return roi_pool(features, rois, self.out_size, self.spatial_scale)
if self.use_torchvision:
from torchvision.ops import roi_pool as tv_roi_pool
return tv_roi_pool(features, rois, _pair(self.out_size),
self.spatial_scale)
else:
return roi_pool(features, rois, self.out_size, self.spatial_scale)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment