From 40ea9ca23a1cbe65dbe800dc065acfe44f5626f8 Mon Sep 17 00:00:00 2001 From: Chongruo Wu <chongruo@gmail.com> Date: Thu, 12 Sep 2019 00:05:40 +0800 Subject: [PATCH] [Modulated DeformableConv] fix a bug for deformable_group>1 (#1359) * fix a bug for Modulated Deformable Convolution * resolve cases for CI test * resolve cases for CI test * resolve cases for CI test * create a new script * change dir name in config * update model url --- configs/dcn/README.md | 2 + ...ter_rcnn_mdconv_c3-c5_group4_r50_fpn_1x.py | 177 ++++++++++++++++++ mmdet/models/backbones/resnet.py | 11 +- 3 files changed, 185 insertions(+), 5 deletions(-) create mode 100644 configs/dcn/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x.py diff --git a/configs/dcn/README.md b/configs/dcn/README.md index d95ee40..12f003f 100644 --- a/configs/dcn/README.md +++ b/configs/dcn/README.md @@ -24,6 +24,7 @@ |:---------:|:------------:|:-------:|:-------------:|:------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:-------:|:--------:| | R-50-FPN | Faster | pytorch | dconv(c3-c5) | - | 1x | 3.9 | 0.594 | 10.2 | 40.0 | - | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_dconv_c3-c5_r50_fpn_1x_20190125-e41688c9.pth) | | R-50-FPN | Faster | pytorch | mdconv(c3-c5) | - | 1x | 3.7 | 0.598 | 10.0 | 40.2 | - | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_mdconv_c3-c5_r50_fpn_1x_20190125-1b768045.pth) | +| *R-50-FPN (dg=4) | Faster | pytorch | mdconv(c3-c5) | - | 1x | 3.9 | 0.350 | 11.36 | 40.4 | - | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/dcn/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x_20190911-5591a7e4.pth) | | R-50-FPN | Faster | pytorch | - | dpool | 1x | 4.6 | 0.714 | 8.7 | 37.8 | - | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_dpool_r50_fpn_1x_20190125-f4fc1d70.pth) | | R-50-FPN | Faster | pytorch | - | mdpool | 1x | 5.2 | 0.769 | 8.2 | 38.0 | - | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_mdpool_r50_fpn_1x_20190125-473d0f3d.pth) | | R-101-FPN | Faster | pytorch | dconv(c3-c5) | - | 1x | 5.8 | 0.811 | 8.0 | 42.1 | - | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_dconv_c3-c5_r101_fpn_1x_20190125-a7e31b65.pth) | @@ -40,4 +41,5 @@ - `dconv` and `mdconv` denote (modulated) deformable convolution, `c3-c5` means adding dconv in resnet stage 3 to 5. `dpool` and `mdpool` denote (modulated) deformable roi pooling. - The dcn ops are modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch, which should be more memory efficient and slightly faster. +- (*) For R-50-FPN (dg=4), dg is short for deformable_group. This model is trained and tested on Amazon EC2 p3dn.24xlarge instance. - **Memory, Train/Inf time is outdated.** \ No newline at end of file diff --git a/configs/dcn/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x.py b/configs/dcn/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x.py new file mode 100644 index 0000000..4e91bb0 --- /dev/null +++ b/configs/dcn/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x.py @@ -0,0 +1,177 @@ +# model settings +model = dict( + type='FasterRCNN', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + style='pytorch', + dcn=dict( + modulated=True, deformable_groups=4, fallback_on_stride=False), + stage_with_dcn=(False, True, True, True)), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_scales=[8], + anchor_ratios=[0.5, 1.0, 2.0], + anchor_strides=[4, 8, 16, 32, 64], + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0], + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='SharedFCBBoxHead', + num_fcs=2, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=81, + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2], + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))) +# model training and testing settings +train_cfg = dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_across_levels=False, + nms_pre=2000, + nms_post=2000, + max_num=2000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)) +test_cfg = dict( + rpn=dict( + nms_across_levels=False, + nms_pre=1000, + nms_post=1000, + max_num=1000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05) +) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + imgs_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline)) +# optimizer +optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[8, 11]) +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +# runtime settings +total_epochs = 12 +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py index eaead29..d87b736 100644 --- a/mmdet/models/backbones/resnet.py +++ b/mmdet/models/backbones/resnet.py @@ -161,7 +161,7 @@ class Bottleneck(nn.Module): bias=False) else: assert conv_cfg is None, 'conv_cfg must be None for DCN' - deformable_groups = dcn.get('deformable_groups', 1) + self.deformable_groups = dcn.get('deformable_groups', 1) if not self.with_modulated_dcn: conv_op = DeformConv offset_channels = 18 @@ -170,7 +170,7 @@ class Bottleneck(nn.Module): offset_channels = 27 self.conv2_offset = nn.Conv2d( planes, - deformable_groups * offset_channels, + self.deformable_groups * offset_channels, kernel_size=3, stride=self.conv2_stride, padding=dilation, @@ -182,7 +182,7 @@ class Bottleneck(nn.Module): stride=self.conv2_stride, padding=dilation, dilation=dilation, - deformable_groups=deformable_groups, + deformable_groups=self.deformable_groups, bias=False) self.add_module(self.norm2_name, norm2) self.conv3 = build_conv_layer( @@ -230,8 +230,9 @@ class Bottleneck(nn.Module): out = self.conv2(out) elif self.with_modulated_dcn: offset_mask = self.conv2_offset(out) - offset = offset_mask[:, :18, :, :] - mask = offset_mask[:, -9:, :, :].sigmoid() + offset = offset_mask[:, :18 * self.deformable_groups, :, :] + mask = offset_mask[:, -9 * self.deformable_groups:, :, :] + mask = mask.sigmoid() out = self.conv2(out, offset, mask) else: offset = self.conv2_offset(out) -- GitLab