From 40ea9ca23a1cbe65dbe800dc065acfe44f5626f8 Mon Sep 17 00:00:00 2001
From: Chongruo Wu <chongruo@gmail.com>
Date: Thu, 12 Sep 2019 00:05:40 +0800
Subject: [PATCH] [Modulated DeformableConv] fix a bug for deformable_group>1 
 (#1359)

* fix a bug for Modulated Deformable Convolution

* resolve cases for CI test

* resolve cases for CI test

* resolve cases for CI test

* create a new script

* change dir name in config

* update model url
---
 configs/dcn/README.md                         |   2 +
 ...ter_rcnn_mdconv_c3-c5_group4_r50_fpn_1x.py | 177 ++++++++++++++++++
 mmdet/models/backbones/resnet.py              |  11 +-
 3 files changed, 185 insertions(+), 5 deletions(-)
 create mode 100644 configs/dcn/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x.py

diff --git a/configs/dcn/README.md b/configs/dcn/README.md
index d95ee40..12f003f 100644
--- a/configs/dcn/README.md
+++ b/configs/dcn/README.md
@@ -24,6 +24,7 @@
 |:---------:|:------------:|:-------:|:-------------:|:------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:-------:|:--------:|
 | R-50-FPN  | Faster       | pytorch | dconv(c3-c5)  | -      | 1x      | 3.9      | 0.594               | 10.2           | 40.0   | -       | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_dconv_c3-c5_r50_fpn_1x_20190125-e41688c9.pth) |
 | R-50-FPN  | Faster       | pytorch | mdconv(c3-c5) | -      | 1x      | 3.7      | 0.598               | 10.0           | 40.2   | -       | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_mdconv_c3-c5_r50_fpn_1x_20190125-1b768045.pth) |
+| *R-50-FPN (dg=4) | Faster | pytorch | mdconv(c3-c5) | -     | 1x      | 3.9      | 0.350               | 11.36           | 40.4   | -       | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/dcn/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x_20190911-5591a7e4.pth) |
 | R-50-FPN  | Faster       | pytorch | -             | dpool  | 1x      | 4.6      | 0.714               | 8.7            | 37.8   | -       | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_dpool_r50_fpn_1x_20190125-f4fc1d70.pth) |
 | R-50-FPN  | Faster       | pytorch | -             | mdpool | 1x      | 5.2      | 0.769               | 8.2            | 38.0   | -       | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_mdpool_r50_fpn_1x_20190125-473d0f3d.pth) |
 | R-101-FPN | Faster       | pytorch | dconv(c3-c5)  | -      | 1x      | 5.8      | 0.811               | 8.0            | 42.1   | -       | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/dcn/faster_rcnn_dconv_c3-c5_r101_fpn_1x_20190125-a7e31b65.pth) |
@@ -40,4 +41,5 @@
 
 - `dconv` and `mdconv` denote (modulated) deformable convolution, `c3-c5` means adding dconv in resnet stage 3 to 5. `dpool` and `mdpool` denote (modulated) deformable roi pooling.
 - The dcn ops are modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch, which should be more memory efficient and slightly faster.
+- (*) For R-50-FPN (dg=4), dg is short for deformable_group. This model is trained and tested on Amazon EC2 p3dn.24xlarge instance.
 - **Memory, Train/Inf time is outdated.**
\ No newline at end of file
diff --git a/configs/dcn/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x.py b/configs/dcn/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x.py
new file mode 100644
index 0000000..4e91bb0
--- /dev/null
+++ b/configs/dcn/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x.py
@@ -0,0 +1,177 @@
+# model settings
+model = dict(
+    type='FasterRCNN',
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch',
+        dcn=dict(
+            modulated=True, deformable_groups=4, fallback_on_stride=False),
+        stage_with_dcn=(False, True, True, True)),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
+    # soft-nms is also supported for rcnn testing
+    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
+)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=32),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(1333, 800),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/faster_rcnn_mdconv_c3-c5_group4_r50_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py
index eaead29..d87b736 100644
--- a/mmdet/models/backbones/resnet.py
+++ b/mmdet/models/backbones/resnet.py
@@ -161,7 +161,7 @@ class Bottleneck(nn.Module):
                 bias=False)
         else:
             assert conv_cfg is None, 'conv_cfg must be None for DCN'
-            deformable_groups = dcn.get('deformable_groups', 1)
+            self.deformable_groups = dcn.get('deformable_groups', 1)
             if not self.with_modulated_dcn:
                 conv_op = DeformConv
                 offset_channels = 18
@@ -170,7 +170,7 @@ class Bottleneck(nn.Module):
                 offset_channels = 27
             self.conv2_offset = nn.Conv2d(
                 planes,
-                deformable_groups * offset_channels,
+                self.deformable_groups * offset_channels,
                 kernel_size=3,
                 stride=self.conv2_stride,
                 padding=dilation,
@@ -182,7 +182,7 @@ class Bottleneck(nn.Module):
                 stride=self.conv2_stride,
                 padding=dilation,
                 dilation=dilation,
-                deformable_groups=deformable_groups,
+                deformable_groups=self.deformable_groups,
                 bias=False)
         self.add_module(self.norm2_name, norm2)
         self.conv3 = build_conv_layer(
@@ -230,8 +230,9 @@ class Bottleneck(nn.Module):
                 out = self.conv2(out)
             elif self.with_modulated_dcn:
                 offset_mask = self.conv2_offset(out)
-                offset = offset_mask[:, :18, :, :]
-                mask = offset_mask[:, -9:, :, :].sigmoid()
+                offset = offset_mask[:, :18 * self.deformable_groups, :, :]
+                mask = offset_mask[:, -9 * self.deformable_groups:, :, :]
+                mask = mask.sigmoid()
                 out = self.conv2(out, offset, mask)
             else:
                 offset = self.conv2_offset(out)
-- 
GitLab