diff --git a/mmdet/ops/__init__.py b/mmdet/ops/__init__.py
index 5f6ad0930c23b5a2cd6115e1c2e593afd5d29a22..c721925efa4cbcee8d594ebee624757d81d01b35 100644
--- a/mmdet/ops/__init__.py
+++ b/mmdet/ops/__init__.py
@@ -2,7 +2,7 @@ from .dcn import (DeformConv, DeformConvPack, ModulatedDeformConv,
                   ModulatedDeformConvPack, DeformRoIPooling,
                   DeformRoIPoolingPack, ModulatedDeformRoIPoolingPack,
                   deform_conv, modulated_deform_conv, deform_roi_pooling)
-from .gcb import ContextBlock
+from .context_block import ContextBlock
 from .nms import nms, soft_nms
 from .roi_align import RoIAlign, roi_align
 from .roi_pool import RoIPool, roi_pool
diff --git a/mmdet/ops/gcb/context_block.py b/mmdet/ops/context_block.py
similarity index 100%
rename from mmdet/ops/gcb/context_block.py
rename to mmdet/ops/context_block.py
diff --git a/mmdet/ops/dcn/__init__.py b/mmdet/ops/dcn/__init__.py
index 165e63725354de429a448d866f665cccca991916..48566be193af852e4475ea0ebbdd2d38f0dcb549 100644
--- a/mmdet/ops/dcn/__init__.py
+++ b/mmdet/ops/dcn/__init__.py
@@ -1,9 +1,8 @@
-from .functions.deform_conv import deform_conv, modulated_deform_conv
-from .functions.deform_pool import deform_roi_pooling
-from .modules.deform_conv import (DeformConv, ModulatedDeformConv,
-                                  DeformConvPack, ModulatedDeformConvPack)
-from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack,
-                                  ModulatedDeformRoIPoolingPack)
+from .deform_conv import (deform_conv, modulated_deform_conv, DeformConv,
+                          DeformConvPack, ModulatedDeformConv,
+                          ModulatedDeformConvPack)
+from .deform_pool import (deform_roi_pooling, DeformRoIPooling,
+                          DeformRoIPoolingPack, ModulatedDeformRoIPoolingPack)
 
 __all__ = [
     'DeformConv', 'DeformConvPack', 'ModulatedDeformConv',
diff --git a/mmdet/ops/dcn/functions/deform_conv.py b/mmdet/ops/dcn/deform_conv.py
similarity index 58%
rename from mmdet/ops/dcn/functions/deform_conv.py
rename to mmdet/ops/dcn/deform_conv.py
index 6af75a758b8448ca1d981054525259f536d99d1e..7f6841a585e504c8a480d5ac134807ae63edcd5d 100644
--- a/mmdet/ops/dcn/functions/deform_conv.py
+++ b/mmdet/ops/dcn/deform_conv.py
@@ -1,8 +1,12 @@
+import math
+
 import torch
+import torch.nn as nn
 from torch.autograd import Function
+from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair
 
-from .. import deform_conv_cuda
+from . import deform_conv_cuda
 
 
 class DeformConvFunction(Function):
@@ -52,6 +56,7 @@ class DeformConvFunction(Function):
         return output
 
     @staticmethod
+    @once_differentiable
     def backward(ctx, grad_output):
         input, offset, weight = ctx.saved_tensors
 
@@ -143,6 +148,7 @@ class ModulatedDeformConvFunction(Function):
         return output
 
     @staticmethod
+    @once_differentiable
    def backward(ctx, grad_output):
         if not grad_output.is_cuda:
             raise NotImplementedError
@@ -179,3 +185,153 @@ class ModulatedDeformConvFunction(Function):
 
 deform_conv = DeformConvFunction.apply
 modulated_deform_conv = ModulatedDeformConvFunction.apply
+
+
+class DeformConv(nn.Module):
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 deformable_groups=1,
+                 bias=False):
+        super(DeformConv, self).__init__()
+
+        assert not bias
+        assert in_channels % groups == 0, \
+            'in_channels {} is not divisible by groups {}'.format(
+                in_channels, groups)
+        assert out_channels % groups == 0, \
+            'out_channels {} is not divisible by groups {}'.format(
+                out_channels, groups)
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = _pair(stride)
+        self.padding = _pair(padding)
+        self.dilation = _pair(dilation)
+        self.groups = groups
+        self.deformable_groups = deformable_groups
+
+        self.weight = nn.Parameter(
+            torch.Tensor(out_channels, in_channels // self.groups,
+                         *self.kernel_size))
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        n = self.in_channels
+        for k in self.kernel_size:
+            n *= k
+        stdv = 1. / math.sqrt(n)
+        self.weight.data.uniform_(-stdv, stdv)
+
+    def forward(self, x, offset):
+        return deform_conv(x, offset, self.weight, self.stride, self.padding,
+                           self.dilation, self.groups, self.deformable_groups)
+
+
+class DeformConvPack(DeformConv):
+
+    def __init__(self, *args, **kwargs):
+        super(DeformConvPack, self).__init__(*args, **kwargs)
+
+        self.conv_offset = nn.Conv2d(
+            self.in_channels,
+            self.deformable_groups * 2 * self.kernel_size[0] *
+            self.kernel_size[1],
+            kernel_size=self.kernel_size,
+            stride=_pair(self.stride),
+            padding=_pair(self.padding),
+            bias=True)
+        self.init_offset()
+
+    def init_offset(self):
+        self.conv_offset.weight.data.zero_()
+        self.conv_offset.bias.data.zero_()
+
+    def forward(self, x):
+        offset = self.conv_offset(x)
+        return deform_conv(x, offset, self.weight, self.stride, self.padding,
+                           self.dilation, self.groups, self.deformable_groups)
+
+
+class ModulatedDeformConv(nn.Module):
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 deformable_groups=1,
+                 bias=True):
+        super(ModulatedDeformConv, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.deformable_groups = deformable_groups
+        self.with_bias = bias
+
+        self.weight = nn.Parameter(
+            torch.Tensor(out_channels, in_channels // groups,
+                         *self.kernel_size))
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(out_channels))
+        else:
+            self.register_parameter('bias', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        n = self.in_channels
+        for k in self.kernel_size:
+            n *= k
+        stdv = 1. / math.sqrt(n)
+        self.weight.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.zero_()
+
+    def forward(self, x, offset, mask):
+        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
+                                     self.stride, self.padding, self.dilation,
+                                     self.groups, self.deformable_groups)
+
+
+class ModulatedDeformConvPack(ModulatedDeformConv):
+
+    def __init__(self, *args, **kwargs):
+        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
+
+        self.conv_offset_mask = nn.Conv2d(
+            self.in_channels,
+            self.deformable_groups * 3 * self.kernel_size[0] *
+            self.kernel_size[1],
+            kernel_size=self.kernel_size,
+            stride=_pair(self.stride),
+            padding=_pair(self.padding),
+            bias=True)
+        self.init_offset()
+
+    def init_offset(self):
+        self.conv_offset_mask.weight.data.zero_()
+        self.conv_offset_mask.bias.data.zero_()
+
+    def forward(self, x):
+        out = self.conv_offset_mask(x)
+        o1, o2, mask = torch.chunk(out, 3, dim=1)
+        offset = torch.cat((o1, o2), dim=1)
+        mask = torch.sigmoid(mask)
+        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
+                                     self.stride, self.padding, self.dilation,
+                                     self.groups, self.deformable_groups)
diff --git a/mmdet/ops/dcn/modules/deform_pool.py b/mmdet/ops/dcn/deform_pool.py
similarity index 58%
rename from mmdet/ops/dcn/modules/deform_pool.py
rename to mmdet/ops/dcn/deform_pool.py
index 5e0196753ee1b427263bc397e0ae842af6a9938b..2d09dec8c82fd01fc947e7258b5bbb54e3245d43 100644
--- a/mmdet/ops/dcn/modules/deform_pool.py
+++ b/mmdet/ops/dcn/deform_pool.py
@@ -1,6 +1,75 @@
-from torch import nn
-
-from ..functions.deform_pool import deform_roi_pooling
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from . import deform_pool_cuda
+
+
+class DeformRoIPoolingFunction(Function):
+
+    @staticmethod
+    def forward(ctx,
+                data,
+                rois,
+                offset,
+                spatial_scale,
+                out_size,
+                out_channels,
+                no_trans,
+                group_size=1,
+                part_size=None,
+                sample_per_part=4,
+                trans_std=.0):
+        ctx.spatial_scale = spatial_scale
+        ctx.out_size = out_size
+        ctx.out_channels = out_channels
+        ctx.no_trans = no_trans
+        ctx.group_size = group_size
+        ctx.part_size = out_size if part_size is None else part_size
+        ctx.sample_per_part = sample_per_part
+        ctx.trans_std = trans_std
+
+        assert 0.0 <= ctx.trans_std <= 1.0
+        if not data.is_cuda:
+            raise NotImplementedError
+
+        n = rois.shape[0]
+        output = data.new_empty(n, out_channels, out_size, out_size)
+        output_count = data.new_empty(n, out_channels, out_size, out_size)
+        deform_pool_cuda.deform_psroi_pooling_cuda_forward(
+            data, rois, offset, output, output_count, ctx.no_trans,
+            ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size,
+            ctx.part_size, ctx.sample_per_part, ctx.trans_std)
+
+        if data.requires_grad or rois.requires_grad or offset.requires_grad:
+            ctx.save_for_backward(data, rois, offset)
+        ctx.output_count = output_count
+
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        if not grad_output.is_cuda:
+            raise NotImplementedError
+
+        data, rois, offset = ctx.saved_tensors
+        output_count = ctx.output_count
+        grad_input = torch.zeros_like(data)
+        grad_rois = None
+        grad_offset = torch.zeros_like(offset)
+
+        deform_pool_cuda.deform_psroi_pooling_cuda_backward(
+            grad_output, data, rois, offset, output_count, grad_input,
+            grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
+            ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
+            ctx.trans_std)
+        return (grad_input, grad_rois, grad_offset, None, None, None, None,
+                None, None, None, None)
+
+
+deform_roi_pooling = DeformRoIPoolingFunction.apply
 
 
 class DeformRoIPooling(nn.Module):
@@ -27,10 +96,11 @@ class DeformRoIPooling(nn.Module):
     def forward(self, data, rois, offset):
         if self.no_trans:
             offset = data.new_empty(0)
-        return deform_roi_pooling(
-            data, rois, offset, self.spatial_scale, self.out_size,
-            self.out_channels, self.no_trans, self.group_size, self.part_size,
-            self.sample_per_part, self.trans_std)
+        return deform_roi_pooling(data, rois, offset, self.spatial_scale,
+                                  self.out_size, self.out_channels,
+                                  self.no_trans, self.group_size,
+                                  self.part_size, self.sample_per_part,
+                                  self.trans_std)
 
 
 class DeformRoIPoolingPack(DeformRoIPooling):
@@ -73,10 +143,11 @@ class DeformRoIPoolingPack(DeformRoIPooling):
         assert data.size(1) == self.out_channels
         if self.no_trans:
             offset = data.new_empty(0)
-            return deform_roi_pooling(
-                data, rois, offset, self.spatial_scale, self.out_size,
-                self.out_channels, self.no_trans, self.group_size,
-                self.part_size, self.sample_per_part, self.trans_std)
+            return deform_roi_pooling(data, rois, offset, self.spatial_scale,
+                                      self.out_size, self.out_channels,
+                                      self.no_trans, self.group_size,
+                                      self.part_size, self.sample_per_part,
+                                      self.trans_std)
         else:
             n = rois.shape[0]
             offset = data.new_empty(0)
@@ -86,10 +157,11 @@ class DeformRoIPoolingPack(DeformRoIPooling):
                                       self.sample_per_part, self.trans_std)
             offset = self.offset_fc(x.view(n, -1))
             offset = offset.view(n, 2, self.out_size, self.out_size)
-            return deform_roi_pooling(
-                data, rois, offset, self.spatial_scale, self.out_size,
-                self.out_channels, self.no_trans, self.group_size,
-                self.part_size, self.sample_per_part, self.trans_std)
+            return deform_roi_pooling(data, rois, offset, self.spatial_scale,
+                                      self.out_size, self.out_channels,
+                                      self.no_trans, self.group_size,
+                                      self.part_size, self.sample_per_part,
+                                      self.trans_std)
 
 
 class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
@@ -106,9 +178,9 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
                  num_offset_fcs=3,
                  num_mask_fcs=2,
                  deform_fc_channels=1024):
-        super(ModulatedDeformRoIPoolingPack, self).__init__(
-            spatial_scale, out_size, out_channels, no_trans, group_size,
-            part_size, sample_per_part, trans_std)
+        super(ModulatedDeformRoIPoolingPack,
+              self).__init__(spatial_scale, out_size, out_channels, no_trans,
+                             group_size, part_size, sample_per_part, trans_std)
 
         self.num_offset_fcs = num_offset_fcs
         self.num_mask_fcs = num_mask_fcs
@@ -151,10 +223,11 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
         assert data.size(1) == self.out_channels
         if self.no_trans:
             offset = data.new_empty(0)
-            return deform_roi_pooling(
-                data, rois, offset, self.spatial_scale, self.out_size,
-                self.out_channels, self.no_trans, self.group_size,
-                self.part_size, self.sample_per_part, self.trans_std)
+            return deform_roi_pooling(data, rois, offset, self.spatial_scale,
+                                      self.out_size, self.out_channels,
+                                      self.no_trans, self.group_size,
+                                      self.part_size, self.sample_per_part,
+                                      self.trans_std)
         else:
             n = rois.shape[0]
             offset = data.new_empty(0)
diff --git a/mmdet/ops/dcn/functions/__init__.py b/mmdet/ops/dcn/functions/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/mmdet/ops/dcn/functions/deform_pool.py b/mmdet/ops/dcn/functions/deform_pool.py
deleted file mode 100644
index 65ff0efb5737e87ccf49387b2d24abcbeedd6497..0000000000000000000000000000000000000000
--- a/mmdet/ops/dcn/functions/deform_pool.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import torch
-from torch.autograd import Function
-
-from .. import deform_pool_cuda
-
-
-class DeformRoIPoolingFunction(Function):
-
-    @staticmethod
-    def forward(ctx,
-                data,
-                rois,
-                offset,
-                spatial_scale,
-                out_size,
-                out_channels,
-                no_trans,
-                group_size=1,
-                part_size=None,
-                sample_per_part=4,
-                trans_std=.0):
-        ctx.spatial_scale = spatial_scale
-        ctx.out_size = out_size
-        ctx.out_channels = out_channels
-        ctx.no_trans = no_trans
-        ctx.group_size = group_size
-        ctx.part_size = out_size if part_size is None else part_size
-        ctx.sample_per_part = sample_per_part
-        ctx.trans_std = trans_std
-
-        assert 0.0 <= ctx.trans_std <= 1.0
-        if not data.is_cuda:
-            raise NotImplementedError
-
-        n = rois.shape[0]
-        output = data.new_empty(n, out_channels, out_size, out_size)
-        output_count = data.new_empty(n, out_channels, out_size, out_size)
-        deform_pool_cuda.deform_psroi_pooling_cuda_forward(
-            data, rois, offset, output, output_count, ctx.no_trans,
-            ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size,
-            ctx.part_size, ctx.sample_per_part, ctx.trans_std)
-
-        if data.requires_grad or rois.requires_grad or offset.requires_grad:
-            ctx.save_for_backward(data, rois, offset)
-        ctx.output_count = output_count
-
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        if not grad_output.is_cuda:
-            raise NotImplementedError
-
-        data, rois, offset = ctx.saved_tensors
-        output_count = ctx.output_count
-        grad_input = torch.zeros_like(data)
-        grad_rois = None
-        grad_offset = torch.zeros_like(offset)
-
-        deform_pool_cuda.deform_psroi_pooling_cuda_backward(
-            grad_output, data, rois, offset, output_count, grad_input,
-            grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
-            ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
-            ctx.trans_std)
-        return (grad_input, grad_rois, grad_offset, None, None, None, None,
-                None, None, None, None)
-
-
-deform_roi_pooling = DeformRoIPoolingFunction.apply
diff --git a/mmdet/ops/dcn/modules/__init__.py b/mmdet/ops/dcn/modules/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/mmdet/ops/dcn/modules/deform_conv.py b/mmdet/ops/dcn/modules/deform_conv.py
deleted file mode 100644
index 50d15d1513f0ebc145982e04958f76a5f1ca1343..0000000000000000000000000000000000000000
--- a/mmdet/ops/dcn/modules/deform_conv.py
+++ /dev/null
@@ -1,157 +0,0 @@
-import math
-
-import torch
-import torch.nn as nn
-from torch.nn.modules.utils import _pair
-
-from ..functions.deform_conv import deform_conv, modulated_deform_conv
-
-
-class DeformConv(nn.Module):
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 deformable_groups=1,
-                 bias=False):
-        super(DeformConv, self).__init__()
-
-        assert not bias
-        assert in_channels % groups == 0, \
-            'in_channels {} cannot be divisible by groups {}'.format(
-                in_channels, groups)
-        assert out_channels % groups == 0, \
-            'out_channels {} cannot be divisible by groups {}'.format(
-                out_channels, groups)
-
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.kernel_size = _pair(kernel_size)
-        self.stride = _pair(stride)
-        self.padding = _pair(padding)
-        self.dilation = _pair(dilation)
-        self.groups = groups
-        self.deformable_groups = deformable_groups
-
-        self.weight = nn.Parameter(
-            torch.Tensor(out_channels, in_channels // self.groups,
-                         *self.kernel_size))
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        n = self.in_channels
-        for k in self.kernel_size:
-            n *= k
-        stdv = 1. / math.sqrt(n)
-        self.weight.data.uniform_(-stdv, stdv)
-
-    def forward(self, x, offset):
-        return deform_conv(x, offset, self.weight, self.stride, self.padding,
-                           self.dilation, self.groups, self.deformable_groups)
-
-
-class DeformConvPack(DeformConv):
-
-    def __init__(self, *args, **kwargs):
-        super(DeformConvPack, self).__init__(*args, **kwargs)
-
-        self.conv_offset = nn.Conv2d(
-            self.in_channels,
-            self.deformable_groups * 2 * self.kernel_size[0] *
-            self.kernel_size[1],
-            kernel_size=self.kernel_size,
-            stride=_pair(self.stride),
-            padding=_pair(self.padding),
-            bias=True)
-        self.init_offset()
-
-    def init_offset(self):
-        self.conv_offset.weight.data.zero_()
-        self.conv_offset.bias.data.zero_()
-
-    def forward(self, x):
-        offset = self.conv_offset(x)
-        return deform_conv(x, offset, self.weight, self.stride, self.padding,
-                           self.dilation, self.groups, self.deformable_groups)
-
-
-class ModulatedDeformConv(nn.Module):
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 deformable_groups=1,
-                 bias=True):
-        super(ModulatedDeformConv, self).__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.kernel_size = _pair(kernel_size)
-        self.stride = stride
-        self.padding = padding
-        self.dilation = dilation
-        self.groups = groups
-        self.deformable_groups = deformable_groups
-        self.with_bias = bias
-
-        self.weight = nn.Parameter(
-            torch.Tensor(out_channels, in_channels // groups,
-                         *self.kernel_size))
-        if bias:
-            self.bias = nn.Parameter(torch.Tensor(out_channels))
-        else:
-            self.register_parameter('bias', None)
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        n = self.in_channels
-        for k in self.kernel_size:
-            n *= k
-        stdv = 1. / math.sqrt(n)
-        self.weight.data.uniform_(-stdv, stdv)
-        if self.bias is not None:
-            self.bias.data.zero_()
-
-    def forward(self, x, offset, mask):
-        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
-                                     self.stride, self.padding, self.dilation,
-                                     self.groups, self.deformable_groups)
-
-
-class ModulatedDeformConvPack(ModulatedDeformConv):
-
-    def __init__(self, *args, **kwargs):
-        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
-
-        self.conv_offset_mask = nn.Conv2d(
-            self.in_channels,
-            self.deformable_groups * 3 * self.kernel_size[0] *
-            self.kernel_size[1],
-            kernel_size=self.kernel_size,
-            stride=_pair(self.stride),
-            padding=_pair(self.padding),
-            bias=True)
-        self.init_offset()
-
-    def init_offset(self):
-        self.conv_offset_mask.weight.data.zero_()
-        self.conv_offset_mask.bias.data.zero_()
-
-    def forward(self, x):
-        out = self.conv_offset_mask(x)
-        o1, o2, mask = torch.chunk(out, 3, dim=1)
-        offset = torch.cat((o1, o2), dim=1)
-        mask = torch.sigmoid(mask)
-        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
-                                     self.stride, self.padding, self.dilation,
-                                     self.groups, self.deformable_groups)
diff --git a/mmdet/ops/gcb/__init__.py b/mmdet/ops/gcb/__init__.py
deleted file mode 100644
index 05dd6251e757a1d976c1d59e69d26cb050f38a36..0000000000000000000000000000000000000000
--- a/mmdet/ops/gcb/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .context_block import ContextBlock
-
-__all__ = [
-    'ContextBlock',
-]
diff --git a/mmdet/ops/masked_conv/__init__.py b/mmdet/ops/masked_conv/__init__.py
index feab953163d704fe59cd4a348056258b55f201f2..6267190019d28da514d1d61c116af915e62697ab 100644
--- a/mmdet/ops/masked_conv/__init__.py
+++ b/mmdet/ops/masked_conv/__init__.py
@@ -1,4 +1,3 @@
-from .functions.masked_conv import masked_conv2d
-from .modules.masked_conv import MaskedConv2d
+from .masked_conv import masked_conv2d, MaskedConv2d
 
 __all__ = ['masked_conv2d', 'MaskedConv2d']
diff --git a/mmdet/ops/masked_conv/functions/__init__.py b/mmdet/ops/masked_conv/functions/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/mmdet/ops/masked_conv/functions/masked_conv.py b/mmdet/ops/masked_conv/masked_conv.py
similarity index 68%
rename from mmdet/ops/masked_conv/functions/masked_conv.py
rename to mmdet/ops/masked_conv/masked_conv.py
index eed32b7374b55ef7e8137525ffc73faf72a7dca9..7d84f503c72e160663ee5677b4e234c618d3449a 100644
--- a/mmdet/ops/masked_conv/functions/masked_conv.py
+++ b/mmdet/ops/masked_conv/masked_conv.py
@@ -1,8 +1,12 @@
 import math
+
 import torch
+import torch.nn as nn
 from torch.autograd import Function
+from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair
-from .. import masked_conv2d_cuda
+
+from . import masked_conv2d_cuda
 
 
 class MaskedConv2dFunction(Function):
@@ -49,8 +53,37 @@ class MaskedConv2dFunction(Function):
         return output
 
     @staticmethod
+    @once_differentiable
     def backward(ctx, grad_output):
         return (None, ) * 5
 
 
 masked_conv2d = MaskedConv2dFunction.apply
+
+
+class MaskedConv2d(nn.Conv2d):
+    """A MaskedConv2d which inherits the official Conv2d.
+
+    The masked forward doesn't implement the backward function and only
+    supports the stride parameter to be 1 currently.
+ """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True): + super(MaskedConv2d, + self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias) + + def forward(self, input, mask=None): + if mask is None: # fallback to the normal Conv2d + return super(MaskedConv2d, self).forward(input) + else: + return masked_conv2d(input, mask, self.weight, self.bias, + self.padding) diff --git a/mmdet/ops/masked_conv/modules/__init__.py b/mmdet/ops/masked_conv/modules/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/mmdet/ops/masked_conv/modules/masked_conv.py b/mmdet/ops/masked_conv/modules/masked_conv.py deleted file mode 100644 index 1b8c434a3fc0933d15e28a615f1584e0f907307b..0000000000000000000000000000000000000000 --- a/mmdet/ops/masked_conv/modules/masked_conv.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch.nn as nn -from ..functions.masked_conv import masked_conv2d - - -class MaskedConv2d(nn.Conv2d): - """A MaskedConv2d which inherits the official Conv2d. - - The masked forward doesn't implement the backward function and only - supports the stride parameter to be 1 currently. - """ - - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True): - super(MaskedConv2d, - self).__init__(in_channels, out_channels, kernel_size, stride, - padding, dilation, groups, bias) - - def forward(self, input, mask=None): - if mask is None: # fallback to the normal Conv2d - return super(MaskedConv2d, self).forward(input) - else: - return masked_conv2d(input, mask, self.weight, self.bias, - self.padding) diff --git a/mmdet/ops/roi_align/__init__.py b/mmdet/ops/roi_align/__init__.py index 4cb037904a24e613c4b15305cdf8ded6c0072a1b..9acb61fe8a1f54b961fe765b240a21c7ccbd3adf 100644 --- a/mmdet/ops/roi_align/__init__.py +++ b/mmdet/ops/roi_align/__init__.py @@ -1,4 +1,3 @@ -from .functions.roi_align import roi_align -from .modules.roi_align import RoIAlign +from .roi_align import roi_align, RoIAlign __all__ = ['roi_align', 'RoIAlign'] diff --git a/mmdet/ops/roi_align/functions/__init__.py b/mmdet/ops/roi_align/functions/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/mmdet/ops/roi_align/modules/__init__.py b/mmdet/ops/roi_align/modules/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/mmdet/ops/roi_align/modules/roi_align.py b/mmdet/ops/roi_align/modules/roi_align.py deleted file mode 100644 index de987bd456c88a093632a96b0fcc57b2a3190e87..0000000000000000000000000000000000000000 --- a/mmdet/ops/roi_align/modules/roi_align.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch.nn as nn -from torch.nn.modules.utils import _pair - -from ..functions.roi_align import roi_align - - -class RoIAlign(nn.Module): - - def __init__(self, - out_size, - spatial_scale, - sample_num=0, - use_torchvision=False): - super(RoIAlign, self).__init__() - - self.out_size = out_size - self.spatial_scale = float(spatial_scale) - self.sample_num = int(sample_num) - self.use_torchvision = use_torchvision - - def forward(self, features, rois): - if self.use_torchvision: - from torchvision.ops import roi_align as tv_roi_align - return tv_roi_align(features, rois, _pair(self.out_size), - self.spatial_scale, 
self.sample_num) - else: - return roi_align(features, rois, self.out_size, self.spatial_scale, - self.sample_num) diff --git a/mmdet/ops/roi_align/functions/roi_align.py b/mmdet/ops/roi_align/roi_align.py similarity index 58% rename from mmdet/ops/roi_align/functions/roi_align.py rename to mmdet/ops/roi_align/roi_align.py index cd2ee9edd10896f2b1728b28b1d47d4883176980..a1fd3641213300050faf81d8033c7f40124266a7 100644 --- a/mmdet/ops/roi_align/functions/roi_align.py +++ b/mmdet/ops/roi_align/roi_align.py @@ -1,7 +1,9 @@ +import torch.nn as nn from torch.autograd import Function +from torch.autograd.function import once_differentiable from torch.nn.modules.utils import _pair -from .. import roi_align_cuda +from . import roi_align_cuda class RoIAlignFunction(Function): @@ -28,6 +30,7 @@ class RoIAlignFunction(Function): return output @staticmethod + @once_differentiable def backward(ctx, grad_output): feature_size = ctx.feature_size spatial_scale = ctx.spatial_scale @@ -51,3 +54,34 @@ class RoIAlignFunction(Function): roi_align = RoIAlignFunction.apply + + +class RoIAlign(nn.Module): + + def __init__(self, + out_size, + spatial_scale, + sample_num=0, + use_torchvision=False): + super(RoIAlign, self).__init__() + + self.out_size = out_size + self.spatial_scale = float(spatial_scale) + self.sample_num = int(sample_num) + self.use_torchvision = use_torchvision + + def forward(self, features, rois): + if self.use_torchvision: + from torchvision.ops import roi_align as tv_roi_align + return tv_roi_align(features, rois, _pair(self.out_size), + self.spatial_scale, self.sample_num) + else: + return roi_align(features, rois, self.out_size, self.spatial_scale, + self.sample_num) + + def __repr__(self): + format_str = self.__class__.__name__ + format_str += '(out_size={}, spatial_scale={}, sample_num={}'.format( + self.out_size, self.spatial_scale, self.sample_num) + format_str += ', use_torchvision={})'.format(self.use_torchvision) + return format_str diff --git a/mmdet/ops/roi_pool/__init__.py b/mmdet/ops/roi_pool/__init__.py index eb2c57eabd6fa002c970c1f8d199d80d0a9b689c..d19d60891d4c137b3b6ff4cab452a922c1c9a216 100644 --- a/mmdet/ops/roi_pool/__init__.py +++ b/mmdet/ops/roi_pool/__init__.py @@ -1,4 +1,3 @@ -from .functions.roi_pool import roi_pool -from .modules.roi_pool import RoIPool +from .roi_pool import roi_pool, RoIPool __all__ = ['roi_pool', 'RoIPool'] diff --git a/mmdet/ops/roi_pool/functions/__init__.py b/mmdet/ops/roi_pool/functions/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/mmdet/ops/roi_pool/modules/__init__.py b/mmdet/ops/roi_pool/modules/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/mmdet/ops/roi_pool/modules/roi_pool.py b/mmdet/ops/roi_pool/modules/roi_pool.py deleted file mode 100644 index c173cbbfd7e9c4f8a7f5cdedc4258fa7e2ccbad2..0000000000000000000000000000000000000000 --- a/mmdet/ops/roi_pool/modules/roi_pool.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch.nn as nn -from torch.nn.modules.utils import _pair - -from ..functions.roi_pool import roi_pool - - -class RoIPool(nn.Module): - - def __init__(self, out_size, spatial_scale, use_torchvision=False): - super(RoIPool, self).__init__() - - self.out_size = out_size - self.spatial_scale = float(spatial_scale) - self.use_torchvision = use_torchvision - - def forward(self, features, rois): - if self.use_torchvision: - from torchvision.ops import 
roi_pool as tv_roi_pool - return tv_roi_pool(features, rois, _pair(self.out_size), - self.spatial_scale) - else: - return roi_pool(features, rois, self.out_size, self.spatial_scale) diff --git a/mmdet/ops/roi_pool/functions/roi_pool.py b/mmdet/ops/roi_pool/roi_pool.py similarity index 59% rename from mmdet/ops/roi_pool/functions/roi_pool.py rename to mmdet/ops/roi_pool/roi_pool.py index 6de40088c62828f917937a12d9ed2708ce2b85c3..981e81d4e2afc7dc82ff714ccd3d285a8e6e2b22 100644 --- a/mmdet/ops/roi_pool/functions/roi_pool.py +++ b/mmdet/ops/roi_pool/roi_pool.py @@ -1,8 +1,10 @@ import torch +import torch.nn as nn from torch.autograd import Function +from torch.autograd.function import once_differentiable from torch.nn.modules.utils import _pair -from .. import roi_pool_cuda +from . import roi_pool_cuda class RoIPoolFunction(Function): @@ -27,6 +29,7 @@ class RoIPoolFunction(Function): return output @staticmethod + @once_differentiable def backward(ctx, grad_output): assert grad_output.is_cuda spatial_scale = ctx.spatial_scale @@ -45,3 +48,28 @@ class RoIPoolFunction(Function): roi_pool = RoIPoolFunction.apply + + +class RoIPool(nn.Module): + + def __init__(self, out_size, spatial_scale, use_torchvision=False): + super(RoIPool, self).__init__() + + self.out_size = out_size + self.spatial_scale = float(spatial_scale) + self.use_torchvision = use_torchvision + + def forward(self, features, rois): + if self.use_torchvision: + from torchvision.ops import roi_pool as tv_roi_pool + return tv_roi_pool(features, rois, _pair(self.out_size), + self.spatial_scale) + else: + return roi_pool(features, rois, self.out_size, self.spatial_scale) + + def __repr__(self): + format_str = self.__class__.__name__ + format_str += '(out_size={}, spatial_scale={}'.format( + self.out_size, self.spatial_scale) + format_str += ', use_torchvision={})'.format(self.use_torchvision) + return format_str diff --git a/mmdet/ops/sigmoid_focal_loss/__init__.py b/mmdet/ops/sigmoid_focal_loss/__init__.py index d0e5abd9e3787d807c871ef4d16dedba4b94ee28..a936cb360d005edb9cbbdce18aa450d11ccf877a 100644 --- a/mmdet/ops/sigmoid_focal_loss/__init__.py +++ b/mmdet/ops/sigmoid_focal_loss/__init__.py @@ -1,3 +1,3 @@ -from .modules.sigmoid_focal_loss import SigmoidFocalLoss, sigmoid_focal_loss +from .sigmoid_focal_loss import SigmoidFocalLoss, sigmoid_focal_loss __all__ = ['SigmoidFocalLoss', 'sigmoid_focal_loss'] diff --git a/mmdet/ops/sigmoid_focal_loss/functions/__init__.py b/mmdet/ops/sigmoid_focal_loss/functions/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/mmdet/ops/sigmoid_focal_loss/modules/__init__.py b/mmdet/ops/sigmoid_focal_loss/modules/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/mmdet/ops/sigmoid_focal_loss/modules/sigmoid_focal_loss.py b/mmdet/ops/sigmoid_focal_loss/modules/sigmoid_focal_loss.py deleted file mode 100644 index 34202b566437a4d4c6fee5b0cf70f630cac29b3f..0000000000000000000000000000000000000000 --- a/mmdet/ops/sigmoid_focal_loss/modules/sigmoid_focal_loss.py +++ /dev/null @@ -1,24 +0,0 @@ -from torch import nn - -from ..functions.sigmoid_focal_loss import sigmoid_focal_loss - - -# TODO: remove this module -class SigmoidFocalLoss(nn.Module): - - def __init__(self, gamma, alpha): - super(SigmoidFocalLoss, self).__init__() - self.gamma = gamma - self.alpha = alpha - - def forward(self, logits, targets): - assert logits.is_cuda - loss = 
sigmoid_focal_loss(logits, targets, self.gamma, self.alpha) - return loss.sum() - - def __repr__(self): - tmpstr = self.__class__.__name__ + "(" - tmpstr += "gamma=" + str(self.gamma) - tmpstr += ", alpha=" + str(self.alpha) - tmpstr += ")" - return tmpstr diff --git a/mmdet/ops/sigmoid_focal_loss/functions/sigmoid_focal_loss.py b/mmdet/ops/sigmoid_focal_loss/sigmoid_focal_loss.py similarity index 63% rename from mmdet/ops/sigmoid_focal_loss/functions/sigmoid_focal_loss.py rename to mmdet/ops/sigmoid_focal_loss/sigmoid_focal_loss.py index e690f76305ba6a8a6aeb946664e7ddc7f72fcf6a..8298f433f7eb1d6dc86cb4a91aa25d1a47cdf57e 100644 --- a/mmdet/ops/sigmoid_focal_loss/functions/sigmoid_focal_loss.py +++ b/mmdet/ops/sigmoid_focal_loss/sigmoid_focal_loss.py @@ -1,7 +1,8 @@ +import torch.nn as nn from torch.autograd import Function from torch.autograd.function import once_differentiable -from .. import sigmoid_focal_loss_cuda +from . import sigmoid_focal_loss_cuda class SigmoidFocalLossFunction(Function): @@ -32,3 +33,22 @@ class SigmoidFocalLossFunction(Function): sigmoid_focal_loss = SigmoidFocalLossFunction.apply + + +# TODO: remove this module +class SigmoidFocalLoss(nn.Module): + + def __init__(self, gamma, alpha): + super(SigmoidFocalLoss, self).__init__() + self.gamma = gamma + self.alpha = alpha + + def forward(self, logits, targets): + assert logits.is_cuda + loss = sigmoid_focal_loss(logits, targets, self.gamma, self.alpha) + return loss.sum() + + def __repr__(self): + tmpstr = self.__class__.__name__ + '(gamma={}, alpha={})'.format( + self.gamma, self.alpha) + return tmpstr
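
Reviewer note, not part of the patch: a quick smoke test of the flattened layout, using only entry points the patch itself exports. It assumes a CUDA build of the compiled extensions; the shapes and layer arguments are illustrative only.

    import torch
    from mmdet.ops import DeformConvPack, RoIAlign
    from mmdet.ops.masked_conv import MaskedConv2d

    x = torch.randn(2, 16, 32, 32).cuda()

    # DeformConvPack predicts its own offsets, so it drops in like a Conv2d.
    dconv = DeformConvPack(16, 16, kernel_size=3, padding=1).cuda()
    y = dconv(x)

    # MaskedConv2d falls back to a plain Conv2d when no mask is given.
    mconv = MaskedConv2d(16, 16, kernel_size=3, padding=1).cuda()
    y = mconv(x)

    # rois rows are (batch_idx, x1, y1, x2, y2).
    rois = x.new_tensor([[0., 0., 0., 15., 15.]])
    align = RoIAlign(out_size=7, spatial_scale=1.0, sample_num=2)
    feats = align(x, rois)  # -> (1, 16, 7, 7)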
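
Also worth flagging for review: wrapping each backward in @once_differentiable changes behavior for higher-order gradients. A minimal sketch with a hypothetical toy Function (not from this patch): backward now runs with grad tracking disabled, so building a second-order graph through it fails loudly instead of silently tracking nothing through the opaque CUDA kernels.

    import torch
    from torch.autograd import Function
    from torch.autograd.function import once_differentiable

    class Square(Function):

        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x * x

        @staticmethod
        @once_differentiable
        def backward(ctx, grad_output):
            # Runs under no_grad; the returned gradient is marked so that
            # differentiating through it again raises a RuntimeError.
            x, = ctx.saved_tensors
            return 2 * x * grad_output

    x = torch.randn(3, requires_grad=True)
    y = (Square.apply(x) * x).sum()  # grad_output into backward depends on x
    g, = torch.autograd.grad(y, x, create_graph=True)
    # g.sum().backward() would now raise a RuntimeError, since backward
    # was marked @once_differentiable and cannot be differentiated twice.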