diff --git a/configs/dcn/cascade_mask_rcnn_dconv_c3-c5_r50_fpn_1x.py b/configs/dcn/cascade_mask_rcnn_dconv_c3-c5_r50_fpn_1x.py index c27fff20cd53d578e716f26d477d62ed25362bcd..40a2d97d06fcf27e5ea929490f0a67f381b0f42d 100644 --- a/configs/dcn/cascade_mask_rcnn_dconv_c3-c5_r50_fpn_1x.py +++ b/configs/dcn/cascade_mask_rcnn_dconv_c3-c5_r50_fpn_1x.py @@ -11,7 +11,7 @@ model = dict( frozen_stages=1, style='pytorch', dcn=dict( - modulated=False, deformable_groups=1, fallback_on_stride=False), + type='DCN', deformable_groups=1, fallback_on_stride=False), stage_with_dcn=(False, True, True, True)), neck=dict( type='FPN', diff --git a/configs/dcn/cascade_rcnn_dconv_c3-c5_r50_fpn_1x.py b/configs/dcn/cascade_rcnn_dconv_c3-c5_r50_fpn_1x.py index 2a4740b392bcd07cc8cd9c2204d0734811fc44ec..9d4402a852d440809c0aadc551e7c74edbe59bdf 100644 --- a/configs/dcn/cascade_rcnn_dconv_c3-c5_r50_fpn_1x.py +++ b/configs/dcn/cascade_rcnn_dconv_c3-c5_r50_fpn_1x.py @@ -11,7 +11,7 @@ model = dict( frozen_stages=1, style='pytorch', dcn=dict( - modulated=False, deformable_groups=1, fallback_on_stride=False), + type='DCN', deformable_groups=1, fallback_on_stride=False), stage_with_dcn=(False, True, True, True)), neck=dict( type='FPN', diff --git a/configs/dcn/faster_rcnn_dconv_c3-c5_r50_fpn_1x.py b/configs/dcn/faster_rcnn_dconv_c3-c5_r50_fpn_1x.py index 11c7dd35e6aee4503de4e4bacb8013434dad8c6f..83967ff92efc512ea7281796117fd884cdc1a7ca 100644 --- a/configs/dcn/faster_rcnn_dconv_c3-c5_r50_fpn_1x.py +++ b/configs/dcn/faster_rcnn_dconv_c3-c5_r50_fpn_1x.py @@ -10,7 +10,7 @@ model = dict( frozen_stages=1, style='pytorch', dcn=dict( - modulated=False, deformable_groups=1, fallback_on_stride=False), + type='DCN', deformable_groups=1, fallback_on_stride=False), stage_with_dcn=(False, True, True, True)), neck=dict( type='FPN', diff --git a/configs/dcn/faster_rcnn_dconv_c3-c5_x101_32x4d_fpn_1x.py b/configs/dcn/faster_rcnn_dconv_c3-c5_x101_32x4d_fpn_1x.py index 9156b0d8ba6f509c593b6ce4613c2566cb958a93..06aa6eeaf77703482125a7790c56650ad73fb49e 100644 --- a/configs/dcn/faster_rcnn_dconv_c3-c5_x101_32x4d_fpn_1x.py +++ b/configs/dcn/faster_rcnn_dconv_c3-c5_x101_32x4d_fpn_1x.py @@ -12,8 +12,7 @@ model = dict( frozen_stages=1, style='pytorch', dcn=dict( - modulated=False, - groups=32, + type='DCN', deformable_groups=1, fallback_on_stride=False), stage_with_dcn=(False, True, True, True)), diff --git a/configs/dcn/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x.py b/configs/dcn/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x.py index 4e91bb0b05b03eabad02acb6365349323125bfe2..e469223ca9035d62c9646cf03063f20d9d581efc 100644 --- a/configs/dcn/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x.py +++ b/configs/dcn/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x.py @@ -10,7 +10,7 @@ model = dict( frozen_stages=1, style='pytorch', dcn=dict( - modulated=True, deformable_groups=4, fallback_on_stride=False), + type='DCNv2', deformable_groups=4, fallback_on_stride=False), stage_with_dcn=(False, True, True, True)), neck=dict( type='FPN', diff --git a/configs/dcn/faster_rcnn_mdconv_c3-c5_r50_fpn_1x.py b/configs/dcn/faster_rcnn_mdconv_c3-c5_r50_fpn_1x.py index 484b4aff10782debbace72a218bf538f9785fefd..2e79fa85dc6a304fd962b753adfa585e208afcee 100644 --- a/configs/dcn/faster_rcnn_mdconv_c3-c5_r50_fpn_1x.py +++ b/configs/dcn/faster_rcnn_mdconv_c3-c5_r50_fpn_1x.py @@ -10,7 +10,7 @@ model = dict( frozen_stages=1, style='pytorch', dcn=dict( - modulated=True, deformable_groups=1, fallback_on_stride=False), + type='DCNv2', deformable_groups=1, fallback_on_stride=False), stage_with_dcn=(False, True, True, True)), neck=dict( type='FPN', diff --git a/configs/dcn/mask_rcnn_dconv_c3-c5_r50_fpn_1x.py b/configs/dcn/mask_rcnn_dconv_c3-c5_r50_fpn_1x.py index c3de699a1c4aff0551830aa4e034edddeb15fd84..8ed0af0de89760f0bff6d5bf2df15419035cdd6f 100644 --- a/configs/dcn/mask_rcnn_dconv_c3-c5_r50_fpn_1x.py +++ b/configs/dcn/mask_rcnn_dconv_c3-c5_r50_fpn_1x.py @@ -10,7 +10,7 @@ model = dict( frozen_stages=1, style='pytorch', dcn=dict( - modulated=False, deformable_groups=1, fallback_on_stride=False), + type='DCN', deformable_groups=1, fallback_on_stride=False), stage_with_dcn=(False, True, True, True)), neck=dict( type='FPN', diff --git a/configs/dcn/mask_rcnn_mdconv_c3-c5_r50_fpn_1x.py b/configs/dcn/mask_rcnn_mdconv_c3-c5_r50_fpn_1x.py new file mode 100644 index 0000000000000000000000000000000000000000..c4902567ba9412e3eb5f7cf1364b04aa0d23f861 --- /dev/null +++ b/configs/dcn/mask_rcnn_mdconv_c3-c5_r50_fpn_1x.py @@ -0,0 +1,190 @@ +# model settings +model = dict( + type='MaskRCNN', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + style='pytorch', + dcn=dict(type='DCNv2', deformable_groups=1, fallback_on_stride=False), + stage_with_dcn=(False, True, True, True)), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_scales=[8], + anchor_ratios=[0.5, 1.0, 2.0], + anchor_strides=[4, 8, 16, 32, 64], + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0], + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='SharedFCBBoxHead', + num_fcs=2, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=81, + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2], + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), + mask_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + mask_head=dict( + type='FCNMaskHead', + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=81, + loss_mask=dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))) +# model training and testing settings +train_cfg = dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_across_levels=False, + nms_pre=2000, + nms_post=2000, + max_num=2000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False)) +test_cfg = dict( + rpn=dict( + nms_across_levels=False, + nms_pre=1000, + nms_post=1000, + max_num=1000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_thr=0.5), + max_per_img=100, + mask_thr_binary=0.5)) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + imgs_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline)) +# optimizer +optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[8, 11]) +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +# runtime settings +total_epochs = 12 +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/mask_rcnn_dconv_c3-c5_r50_fpn_1x' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/configs/empirical_attention/faster_rcnn_r50_fpn_attention_0010_dcn_1x.py b/configs/empirical_attention/faster_rcnn_r50_fpn_attention_0010_dcn_1x.py index d264b9bf4ad1f18c533ea7fe35b4ef27117a95e8..7d0843e560062a609483141218d3373ca99c7dbf 100644 --- a/configs/empirical_attention/faster_rcnn_r50_fpn_attention_0010_dcn_1x.py +++ b/configs/empirical_attention/faster_rcnn_r50_fpn_attention_0010_dcn_1x.py @@ -13,7 +13,7 @@ model = dict( spatial_range=-1, num_heads=8, attention_type='0010', kv_stride=2), stage_with_gen_attention=[[], [], [0, 1, 2, 3, 4, 5], [0, 1, 2]], dcn=dict( - modulated=False, deformable_groups=1, fallback_on_stride=False), + type='DCN', deformable_groups=1, fallback_on_stride=False), stage_with_dcn=(False, True, True, True), ), neck=dict( diff --git a/configs/empirical_attention/faster_rcnn_r50_fpn_attention_1111_dcn_1x.py b/configs/empirical_attention/faster_rcnn_r50_fpn_attention_1111_dcn_1x.py index f39edee846eaeb5e886f06b88141da7d96e17ed0..f603e21d67b93a8ce6ed64ffa5690e2701dd129d 100644 --- a/configs/empirical_attention/faster_rcnn_r50_fpn_attention_1111_dcn_1x.py +++ b/configs/empirical_attention/faster_rcnn_r50_fpn_attention_1111_dcn_1x.py @@ -13,7 +13,7 @@ model = dict( spatial_range=-1, num_heads=8, attention_type='1111', kv_stride=2), stage_with_gen_attention=[[], [], [0, 1, 2, 3, 4, 5], [0, 1, 2]], dcn=dict( - modulated=False, deformable_groups=1, fallback_on_stride=False), + type='DCN', deformable_groups=1, fallback_on_stride=False), stage_with_dcn=(False, True, True, True), ), neck=dict( diff --git a/configs/htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py b/configs/htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py index 275b7fbb590d0272651f247723a0484ee87642e7..b9139b3723c65df737a27fe610d5bb1c26882d9f 100644 --- a/configs/htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py +++ b/configs/htc/htc_dconv_c3-c5_mstrain_400_1400_x101_64x4d_fpn_20e.py @@ -15,7 +15,7 @@ model = dict( frozen_stages=1, style='pytorch', dcn=dict( - modulated=False, + type='DCN', groups=64, deformable_groups=1, fallback_on_stride=False), diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py index 3343c5c504e39b72dcd118bbcc513148e39a7e5b..d2f3b7c517eae3c0c6efeb9d7e4ab65389c840aa 100644 --- a/mmdet/models/backbones/resnet.py +++ b/mmdet/models/backbones/resnet.py @@ -5,7 +5,7 @@ from mmcv.runner import load_checkpoint from torch.nn.modules.batchnorm import _BatchNorm from mmdet.models.plugins import GeneralizedAttention -from mmdet.ops import ContextBlock, DeformConv, ModulatedDeformConv +from mmdet.ops import ContextBlock from ..registry import BACKBONES from ..utils import build_conv_layer, build_norm_layer @@ -143,10 +143,8 @@ class Bottleneck(nn.Module): bias=False) self.add_module(self.norm1_name, norm1) fallback_on_stride = False - self.with_modulated_dcn = False if self.with_dcn: - fallback_on_stride = dcn.get('fallback_on_stride', False) - self.with_modulated_dcn = dcn.get('modulated', False) + fallback_on_stride = dcn.pop('fallback_on_stride', False) if not self.with_dcn or fallback_on_stride: self.conv2 = build_conv_layer( conv_cfg, @@ -158,30 +156,17 @@ class Bottleneck(nn.Module): dilation=dilation, bias=False) else: - assert conv_cfg is None, 'conv_cfg must be None for DCN' - self.deformable_groups = dcn.get('deformable_groups', 1) - if not self.with_modulated_dcn: - conv_op = DeformConv - offset_channels = 18 - else: - conv_op = ModulatedDeformConv - offset_channels = 27 - self.conv2_offset = nn.Conv2d( - planes, - self.deformable_groups * offset_channels, - kernel_size=3, - stride=self.conv2_stride, - padding=dilation, - dilation=dilation) - self.conv2 = conv_op( + assert self.conv_cfg is None, 'conv_cfg cannot be None for DCN' + self.conv2 = build_conv_layer( + dcn, planes, planes, kernel_size=3, stride=self.conv2_stride, padding=dilation, dilation=dilation, - deformable_groups=self.deformable_groups, bias=False) + self.add_module(self.norm2_name, norm2) self.conv3 = build_conv_layer( conv_cfg, @@ -224,17 +209,7 @@ class Bottleneck(nn.Module): out = self.norm1(out) out = self.relu(out) - if not self.with_dcn: - out = self.conv2(out) - elif self.with_modulated_dcn: - offset_mask = self.conv2_offset(out) - offset = offset_mask[:, :18 * self.deformable_groups, :, :] - mask = offset_mask[:, -9 * self.deformable_groups:, :, :] - mask = mask.sigmoid() - out = self.conv2(out, offset, mask) - else: - offset = self.conv2_offset(out) - out = self.conv2(out, offset) + out = self.conv2(out) out = self.norm2(out) out = self.relu(out) diff --git a/mmdet/models/backbones/resnext.py b/mmdet/models/backbones/resnext.py index 38afee28b6f99307e2497203b93b7dba317648b6..0c184abb6ad6bae6f73b8a9636b26ba746f9acb2 100644 --- a/mmdet/models/backbones/resnext.py +++ b/mmdet/models/backbones/resnext.py @@ -2,7 +2,6 @@ import math import torch.nn as nn -from mmdet.ops import DeformConv, ModulatedDeformConv from ..registry import BACKBONES from ..utils import build_conv_layer, build_norm_layer from .resnet import Bottleneck as _Bottleneck @@ -41,8 +40,7 @@ class Bottleneck(_Bottleneck): fallback_on_stride = False self.with_modulated_dcn = False if self.with_dcn: - fallback_on_stride = self.dcn.get('fallback_on_stride', False) - self.with_modulated_dcn = self.dcn.get('modulated', False) + fallback_on_stride = self.dcn.pop('fallback_on_stride', False) if not self.with_dcn or fallback_on_stride: self.conv2 = build_conv_layer( self.conv_cfg, @@ -56,22 +54,8 @@ class Bottleneck(_Bottleneck): bias=False) else: assert self.conv_cfg is None, 'conv_cfg must be None for DCN' - groups = self.dcn.get('groups', 1) - deformable_groups = self.dcn.get('deformable_groups', 1) - if not self.with_modulated_dcn: - conv_op = DeformConv - offset_channels = 18 - else: - conv_op = ModulatedDeformConv - offset_channels = 27 - self.conv2_offset = nn.Conv2d( - width, - deformable_groups * offset_channels, - kernel_size=3, - stride=self.conv2_stride, - padding=self.dilation, - dilation=self.dilation) - self.conv2 = conv_op( + self.conv2 = build_conv_layer( + self.dcn, width, width, kernel_size=3, @@ -79,8 +63,8 @@ class Bottleneck(_Bottleneck): padding=self.dilation, dilation=self.dilation, groups=groups, - deformable_groups=deformable_groups, bias=False) + self.add_module(self.norm2_name, norm2) self.conv3 = build_conv_layer( self.conv_cfg, diff --git a/mmdet/models/utils/conv_module.py b/mmdet/models/utils/conv_module.py index f2bddc445b5eb5217695b734cf8a6d4006915033..a1771f8cdcd9a08e932543da2cdc99f2c36997cc 100644 --- a/mmdet/models/utils/conv_module.py +++ b/mmdet/models/utils/conv_module.py @@ -3,12 +3,15 @@ import warnings import torch.nn as nn from mmcv.cnn import constant_init, kaiming_init +from mmdet.ops import DeformConvPack, ModulatedDeformConvPack from .conv_ws import ConvWS2d from .norm import build_norm_layer conv_cfg = { 'Conv': nn.Conv2d, 'ConvWS': ConvWS2d, + 'DCN': DeformConvPack, + 'DCNv2': ModulatedDeformConvPack, # TODO: octave conv } diff --git a/mmdet/ops/dcn/deform_conv.py b/mmdet/ops/dcn/deform_conv.py index 7f6841a585e504c8a480d5ac134807ae63edcd5d..24e5b08cbfd2a9581fae59fe56ed419ea9441e31 100644 --- a/mmdet/ops/dcn/deform_conv.py +++ b/mmdet/ops/dcn/deform_conv.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn from torch.autograd import Function from torch.autograd.function import once_differentiable -from torch.nn.modules.utils import _pair +from torch.nn.modules.utils import _pair, _single from . import deform_conv_cuda @@ -24,7 +24,7 @@ class DeformConvFunction(Function): im2col_step=64): if input is not None and input.dim() != 4: raise ValueError( - "Expected 4D tensor as input, got {}D tensor instead.".format( + 'Expected 4D tensor as input, got {}D tensor instead.'.format( input.dim())) ctx.stride = _pair(stride) ctx.padding = _pair(padding) @@ -105,7 +105,7 @@ class DeformConvFunction(Function): output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) if not all(map(lambda s: s > 0, output_size)): raise ValueError( - "convolution input is too small (output would be {})".format( + 'convolution input is too small (output would be {})'.format( 'x'.join(map(str, output_size)))) return output_size @@ -217,6 +217,9 @@ class DeformConv(nn.Module): self.dilation = _pair(dilation) self.groups = groups self.deformable_groups = deformable_groups + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) self.weight = nn.Parameter( torch.Tensor(out_channels, in_channels // self.groups, @@ -237,6 +240,22 @@ class DeformConv(nn.Module): class DeformConvPack(DeformConv): + """A Deformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 def __init__(self, *args, **kwargs): super(DeformConvPack, self).__init__(*args, **kwargs) @@ -260,6 +279,33 @@ class DeformConvPack(DeformConv): return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, self.deformable_groups) + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + version = local_metadata.get('version', None) + + if version is None or version < 2: + # the key is different in early versions + # In version < 2, DeformConvPack loads previous benchmark models. + if (prefix + 'conv_offset.weight' not in state_dict + and prefix[:-1] + '_offset.weight' in state_dict): + state_dict[prefix + 'conv_offset.weight'] = state_dict.pop( + prefix[:-1] + '_offset.weight') + if (prefix + 'conv_offset.bias' not in state_dict + and prefix[:-1] + '_offset.bias' in state_dict): + state_dict[prefix + + 'conv_offset.bias'] = state_dict.pop(prefix[:-1] + + '_offset.bias') + + if version is not None and version > 1: + from mmdet.apis import get_root_logger + logger = get_root_logger() + logger.info('DeformConvPack {} is upgraded to version 2.'.format( + prefix.rstrip('.'))) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, + error_msgs) + class ModulatedDeformConv(nn.Module): @@ -283,6 +329,9 @@ class ModulatedDeformConv(nn.Module): self.groups = groups self.deformable_groups = deformable_groups self.with_bias = bias + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) self.weight = nn.Parameter( torch.Tensor(out_channels, in_channels // groups, @@ -309,11 +358,27 @@ class ModulatedDeformConv(nn.Module): class ModulatedDeformConvPack(ModulatedDeformConv): + """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 def __init__(self, *args, **kwargs): super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) - self.conv_offset_mask = nn.Conv2d( + self.conv_offset = nn.Conv2d( self.in_channels, self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], @@ -324,14 +389,43 @@ class ModulatedDeformConvPack(ModulatedDeformConv): self.init_offset() def init_offset(self): - self.conv_offset_mask.weight.data.zero_() - self.conv_offset_mask.bias.data.zero_() + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() def forward(self, x): - out = self.conv_offset_mask(x) + out = self.conv_offset(x) o1, o2, mask = torch.chunk(out, 3, dim=1) offset = torch.cat((o1, o2), dim=1) mask = torch.sigmoid(mask) return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups, self.deformable_groups) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + version = local_metadata.get('version', None) + + if version is None or version < 2: + # the key is different in early versions + # In version < 2, ModulatedDeformConvPack + # loads previous benchmark models. + if (prefix + 'conv_offset.weight' not in state_dict + and prefix[:-1] + '_offset.weight' in state_dict): + state_dict[prefix + 'conv_offset.weight'] = state_dict.pop( + prefix[:-1] + '_offset.weight') + if (prefix + 'conv_offset.bias' not in state_dict + and prefix[:-1] + '_offset.bias' in state_dict): + state_dict[prefix + + 'conv_offset.bias'] = state_dict.pop(prefix[:-1] + + '_offset.bias') + + if version is not None and version > 1: + from mmdet.apis import get_root_logger + logger = get_root_logger() + logger.info( + 'ModulatedDeformConvPack {} is upgraded to version 2.'.format( + prefix.rstrip('.'))) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, + error_msgs) diff --git a/requirements.txt b/requirements.txt index 5cacde1c99f1c0db65011225a8524491e223775c..01cfcdedb6fb4037b9306e389a8469aac799a8fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ albumentations>=0.3.2 imagecorruptions matplotlib -mmcv>=0.2.15 +mmcv>=0.2.16 numpy pycocotools six