From 6fe5ccde85d170f52e3fdaa50a9a288e93443fa5 Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Fri, 26 Apr 2019 21:36:55 -0700
Subject: [PATCH] Add support for weight standardization (#521)

* add support for weight standardization

* add ws support for htc heads

* add a config file for 20-23-24e lr schedule

* update readme of weight standardization
---
 README.md                                     |   1 +
 configs/gn+ws/README.md                       |  54 ++++++
 configs/gn+ws/faster_rcnn_r50_fpn_gn_ws_1x.py | 165 ++++++++++++++++
 .../mask_rcnn_r50_fpn_gn_ws_20_23_24e.py      | 180 +++++++++++++++++
 configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_2x.py   | 180 +++++++++++++++++
 .../mask_rcnn_x101_32x4d_fpn_gn_ws_2x.py      | 182 ++++++++++++++++++
 mmdet/models/backbones/resnet.py              |  71 ++++---
 mmdet/models/backbones/resnext.py             |  24 ++-
 mmdet/models/bbox_heads/convfc_bbox_head.py   |   3 +
 mmdet/models/mask_heads/fcn_mask_head.py      |   3 +
 .../models/mask_heads/fused_semantic_head.py  |   5 +
 mmdet/models/mask_heads/htc_mask_head.py      |   1 +
 mmdet/models/necks/fpn.py                     |   3 +
 mmdet/models/utils/__init__.py                |   8 +-
 mmdet/models/utils/conv_module.py             |  41 +++-
 mmdet/models/utils/conv_ws.py                 |  46 +++++
 16 files changed, 935 insertions(+), 32 deletions(-)
 create mode 100644 configs/gn+ws/README.md
 create mode 100644 configs/gn+ws/faster_rcnn_r50_fpn_gn_ws_1x.py
 create mode 100644 configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_20_23_24e.py
 create mode 100644 configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_2x.py
 create mode 100644 configs/gn+ws/mask_rcnn_x101_32x4d_fpn_gn_ws_2x.py
 create mode 100644 mmdet/models/utils/conv_ws.py

diff --git a/README.md b/README.md
index a6b40e4..bffd59a 100644
--- a/README.md
+++ b/README.md
@@ -96,6 +96,7 @@ Results and models are available in the [Model zoo](MODEL_ZOO.md).
 Other features
 - [x] DCNv2
 - [x] Group Normalization
+- [x] Weight Standardization
 - [x] OHEM
 - [x] Soft-NMS
 
diff --git a/configs/gn+ws/README.md b/configs/gn+ws/README.md
new file mode 100644
index 0000000..511f22c
--- /dev/null
+++ b/configs/gn+ws/README.md
@@ -0,0 +1,54 @@
+# Weight Standardization
+
+## Introduction
+
+```
+@article{weightstandardization,
+  author    = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille},
+  title     = {Weight Standardization},
+  journal   = {arXiv preprint arXiv:1903.10520},
+  year      = {2019},
+}
+```
+
+## Results and Models
+
+Faster R-CNN
+
+| Backbone  | Style   | Normalization | Lr schd | box AP | mask AP | Download |
+|:---------:|:-------:|:-------------:|:-------:|:------:|:-------:|:--------:|
+| R-50-FPN  | pytorch | GN            | 1x      | 37.8   | -       | - |
+| R-50-FPN  | pytorch | GN+WS         | 1x      | 38.9   | -       | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ws/faster_rcnn_r50_fpn_gn_ws_1x_20190418-935d00b6.pth) |
+| R-101-FPN | pytorch | GN            | 1x      | 39.8   | -       | - |
+| R-101-FPN | pytorch | GN+WS         | 1x      | 41.4   | -       | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ws/faster_rcnn_r101_fpn_gn_ws_1x_20190419-728705ec.pth) |
+| X-50-32x4d-FPN | pytorch | GN       | 1x      | 36.5   | -       | - |
+| X-50-32x4d-FPN | pytorch | GN+WS    | 1x      | 39.9   | -       | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ws/faster_rcnn_x50_32x4d_fpn_gn_ws_1x_20190419-4e61072b.pth) |
+| X-101-32x4d-FPN | pytorch | GN      | 1x      | 33.2   | -       | - |
+| X-101-32x4d-FPN | pytorch | GN+WS   | 1x      | 41.8   | -       | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ws/faster_rcnn_x101_32x4d_fpn_gn_ws_1x_20190419-c78e5583.pth) |
+
+Mask R-CNN
+
+| Backbone  | Style   | Normalization | Lr schd | box AP | mask AP | Download |
+|:---------:|:-------:|:-------------:|:-------:|:------:|:-------:|:--------:|
+| R-50-FPN  | pytorch | GN            | 2x      | 39.9   | 36.0    | - |
+| R-50-FPN  | pytorch | GN+WS         | 2x      | 40.3   | 36.2    | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ws/mask_rcnn_r50_fpn_gn_ws_2x_20190419-9ec97bbb.pth) |
+| R-101-FPN | pytorch | GN            | 2x      | 41.6   | 37.3    | - |
+| R-101-FPN | pytorch | GN+WS         | 2x      | 42.0   | 37.3    | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ws/mask_rcnn_r101_fpn_gn_ws_2x_20190419-bc7399a6.pth) |
+| X-50-32x4d-FPN | pytorch | GN       | 2x      | 39.2   | 35.5    | - |
+| X-50-32x4d-FPN | pytorch | GN+WS    | 2x      | 40.7   | 36.7    | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ws/mask_rcnn_x50_32x4d_fpn_gn_ws_2x_20190419-2110205e.pth) |
+| X-101-32x4d-FPN | pytorch | GN      | 2x      | 36.4   | 33.1    | - |
+| X-101-32x4d-FPN | pytorch | GN+WS   | 2x      | 42.1   | 37.7    | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ws/mask_rcnn_x101_32x4d_fpn_gn_ws_2x_20190419-7777b15f.pth) |
+| R-50-FPN  | pytorch | GN            | 20-23-24e | 40.6   | 36.6    | - |
+| R-50-FPN  | pytorch | GN+WS         | 20-23-24e | 41.1   | 37.0    | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ws/mask_rcnn_r50_fpn_gn_ws_20_23_24e_20190425-1d9e499e.pth) |
+| R-101-FPN | pytorch | GN            | 20-23-24e | 42.3   | 38.1    | - |
+| R-101-FPN | pytorch | GN+WS         | 20-23-24e | 43.0   | 38.4    | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ws/mask_rcnn_r101_fpn_gn_ws_20_23_24e_20190425-66cb3792.pth) |
+| X-50-32x4d-FPN | pytorch | GN       | 20-23-24e | 39.6   | 35.9    | - |
+| X-50-32x4d-FPN | pytorch | GN+WS    | 20-23-24e | 41.9   | 37.7    | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ws/mask_rcnn_x50_32x4d_fpn_gn_ws_20_23_24e_20190425-d01e2200.pth) |
+| X-101-32x4d-FPN | pytorch | GN      | 20-23-24e | 36.6   | 33.4    | - |
+| X-101-32x4d-FPN | pytorch | GN+WS   | 20-23-24e | 43.4   | 38.7    | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ws/mask_rcnn_x101_32x4d_fpn_gn_ws_20_23_24e_20190425-1ff3e5b2.pth) |
+
+Note:
+
+- GN+WS requires about 5% more memory than GN, and it is only 5% slower than GN.
+- In the paper, a 20-23-24e lr schedule is used instead of 2x.
+- The X-50-GN and X-101-GN pretrained models are also shared by the authors.
\ No newline at end of file
diff --git a/configs/gn+ws/faster_rcnn_r50_fpn_gn_ws_1x.py b/configs/gn+ws/faster_rcnn_r50_fpn_gn_ws_1x.py
new file mode 100644
index 0000000..d494043
--- /dev/null
+++ b/configs/gn+ws/faster_rcnn_r50_fpn_gn_ws_1x.py
@@ -0,0 +1,165 @@
+# model settings
+conv_cfg = dict(type='ConvWS')
+normalize = dict(type='GN', num_groups=32, frozen=False)
+model = dict(
+    type='FasterRCNN',
+    pretrained='open-mmlab://jhu/resnet50_gn_ws',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch',
+        conv_cfg=conv_cfg,
+        normalize=normalize),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5,
+        conv_cfg=conv_cfg,
+        normalize=normalize),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        use_sigmoid_cls=True),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='ConvFCBBoxHead',
+        num_shared_convs=4,
+        num_shared_fcs=1,
+        in_channels=256,
+        conv_out_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        conv_cfg=conv_cfg,
+        normalize=normalize))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        smoothl1_beta=1 / 9.0,
+        debug=False),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05,
+        nms=dict(type='nms', iou_thr=0.5),
+        max_per_img=100))
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/faster_rcnn_r50_fpn_gn_ws_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_20_23_24e.py b/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_20_23_24e.py
new file mode 100644
index 0000000..2d98767
--- /dev/null
+++ b/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_20_23_24e.py
@@ -0,0 +1,180 @@
+# model settings
+conv_cfg = dict(type='ConvWS')
+normalize = dict(type='GN', num_groups=32, frozen=False)
+model = dict(
+    type='MaskRCNN',
+    pretrained='open-mmlab://jhu/resnet50_gn_ws',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch',
+        conv_cfg=conv_cfg,
+        normalize=normalize),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5,
+        conv_cfg=conv_cfg,
+        normalize=normalize),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        use_sigmoid_cls=True),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='ConvFCBBoxHead',
+        num_shared_convs=4,
+        num_shared_fcs=1,
+        in_channels=256,
+        conv_out_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        conv_cfg=conv_cfg,
+        normalize=normalize),
+    mask_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    mask_head=dict(
+        type='FCNMaskHead',
+        num_convs=4,
+        in_channels=256,
+        conv_out_channels=256,
+        num_classes=81,
+        conv_cfg=conv_cfg,
+        normalize=normalize))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        smoothl1_beta=1 / 9.0,
+        debug=False),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        mask_size=28,
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05,
+        nms=dict(type='nms', iou_thr=0.5),
+        max_per_img=100,
+        mask_thr_binary=0.5))
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[20, 23])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 24
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/mask_rcnn_r50_fpn_gn_ws_20_23_24e'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_2x.py b/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_2x.py
new file mode 100644
index 0000000..c28c6ed
--- /dev/null
+++ b/configs/gn+ws/mask_rcnn_r50_fpn_gn_ws_2x.py
@@ -0,0 +1,180 @@
+# model settings
+conv_cfg = dict(type='ConvWS')
+normalize = dict(type='GN', num_groups=32, frozen=False)
+model = dict(
+    type='MaskRCNN',
+    pretrained='open-mmlab://jhu/resnet50_gn_ws',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch',
+        conv_cfg=conv_cfg,
+        normalize=normalize),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5,
+        conv_cfg=conv_cfg,
+        normalize=normalize),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        use_sigmoid_cls=True),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='ConvFCBBoxHead',
+        num_shared_convs=4,
+        num_shared_fcs=1,
+        in_channels=256,
+        conv_out_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        conv_cfg=conv_cfg,
+        normalize=normalize),
+    mask_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    mask_head=dict(
+        type='FCNMaskHead',
+        num_convs=4,
+        in_channels=256,
+        conv_out_channels=256,
+        num_classes=81,
+        conv_cfg=conv_cfg,
+        normalize=normalize))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        smoothl1_beta=1 / 9.0,
+        debug=False),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        mask_size=28,
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05,
+        nms=dict(type='nms', iou_thr=0.5),
+        max_per_img=100,
+        mask_thr_binary=0.5))
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[16, 22])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 24
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/mask_rcnn_r50_fpn_gn_ws_2x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/gn+ws/mask_rcnn_x101_32x4d_fpn_gn_ws_2x.py b/configs/gn+ws/mask_rcnn_x101_32x4d_fpn_gn_ws_2x.py
new file mode 100644
index 0000000..8fdeaa0
--- /dev/null
+++ b/configs/gn+ws/mask_rcnn_x101_32x4d_fpn_gn_ws_2x.py
@@ -0,0 +1,182 @@
+# model settings
+conv_cfg = dict(type='ConvWS')
+normalize = dict(type='GN', num_groups=32, frozen=False)
+model = dict(
+    type='MaskRCNN',
+    pretrained='open-mmlab://jhu/resnext101_32x4d_gn_ws',
+    backbone=dict(
+        type='ResNeXt',
+        depth=101,
+        groups=32,
+        base_width=4,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch',
+        conv_cfg=conv_cfg,
+        normalize=normalize),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5,
+        conv_cfg=conv_cfg,
+        normalize=normalize),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        use_sigmoid_cls=True),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='ConvFCBBoxHead',
+        num_shared_convs=4,
+        num_shared_fcs=1,
+        in_channels=256,
+        conv_out_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        conv_cfg=conv_cfg,
+        normalize=normalize),
+    mask_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    mask_head=dict(
+        type='FCNMaskHead',
+        num_convs=4,
+        in_channels=256,
+        conv_out_channels=256,
+        num_classes=81,
+        conv_cfg=conv_cfg,
+        normalize=normalize))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        smoothl1_beta=1 / 9.0,
+        debug=False),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        mask_size=28,
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05,
+        nms=dict(type='nms', iou_thr=0.5),
+        max_per_img=100,
+        mask_thr_binary=0.5))
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[16, 22])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 24
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/mask_rcnn_x101_32x4d_fpn_gn_ws_2x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py
index 1965d6e..ba7e5fe 100644
--- a/mmdet/models/backbones/resnet.py
+++ b/mmdet/models/backbones/resnet.py
@@ -8,19 +8,7 @@ from mmcv.runner import load_checkpoint
 
 from mmdet.ops import DeformConv, ModulatedDeformConv
 from ..registry import BACKBONES
-from ..utils import build_norm_layer
-
-
-def conv3x3(in_planes, out_planes, stride=1, dilation=1):
-    "3x3 convolution with padding"
-    return nn.Conv2d(
-        in_planes,
-        out_planes,
-        kernel_size=3,
-        stride=stride,
-        padding=dilation,
-        dilation=dilation,
-        bias=False)
+from ..utils import build_conv_layer, build_norm_layer
 
 
 class BasicBlock(nn.Module):
@@ -34,6 +22,7 @@ class BasicBlock(nn.Module):
                  downsample=None,
                  style='pytorch',
                  with_cp=False,
+                 conv_cfg=None,
                  normalize=dict(type='BN'),
                  dcn=None):
         super(BasicBlock, self).__init__()
@@ -42,9 +31,25 @@ class BasicBlock(nn.Module):
         self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
         self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
 
-        self.conv1 = conv3x3(inplanes, planes, stride, dilation)
+        self.conv1 = build_conv_layer(
+            conv_cfg,
+            inplanes,
+            planes,
+            3,
+            stride=stride,
+            padding=dilation,
+            dilation=dilation,
+            bias=False)
         self.add_module(self.norm1_name, norm1)
-        self.conv2 = conv3x3(planes, planes)
+        self.conv2 = build_conv_layer(
+            conv_cfg,
+            planes,
+            planes,
+            3,
+            stride=stride,
+            padding=dilation,
+            dilation=dilation,
+            bias=False)
         self.add_module(self.norm2_name, norm2)
 
         self.relu = nn.ReLU(inplace=True)
@@ -91,6 +96,7 @@ class Bottleneck(nn.Module):
                  downsample=None,
                  style='pytorch',
                  with_cp=False,
+                 conv_cfg=None,
                  normalize=dict(type='BN'),
                  dcn=None):
         """Bottleneck block for ResNet.
@@ -102,6 +108,7 @@ class Bottleneck(nn.Module):
         assert dcn is None or isinstance(dcn, dict)
         self.inplanes = inplanes
         self.planes = planes
+        self.conv_cfg = conv_cfg
         self.normalize = normalize
         self.dcn = dcn
         self.with_dcn = dcn is not None
@@ -117,7 +124,8 @@ class Bottleneck(nn.Module):
         self.norm3_name, norm3 = build_norm_layer(
             normalize, planes * self.expansion, postfix=3)
 
-        self.conv1 = nn.Conv2d(
+        self.conv1 = build_conv_layer(
+            conv_cfg,
             inplanes,
             planes,
             kernel_size=1,
@@ -130,7 +138,8 @@ class Bottleneck(nn.Module):
             fallback_on_stride = dcn.get('fallback_on_stride', False)
             self.with_modulated_dcn = dcn.get('modulated', False)
         if not self.with_dcn or fallback_on_stride:
-            self.conv2 = nn.Conv2d(
+            self.conv2 = build_conv_layer(
+                conv_cfg,
                 planes,
                 planes,
                 kernel_size=3,
@@ -139,6 +148,7 @@ class Bottleneck(nn.Module):
                 dilation=dilation,
                 bias=False)
         else:
+            assert conv_cfg is None, 'conv_cfg must be None for DCN'
             deformable_groups = dcn.get('deformable_groups', 1)
             if not self.with_modulated_dcn:
                 conv_op = DeformConv
@@ -163,8 +173,12 @@ class Bottleneck(nn.Module):
                 deformable_groups=deformable_groups,
                 bias=False)
         self.add_module(self.norm2_name, norm2)
-        self.conv3 = nn.Conv2d(
-            planes, planes * self.expansion, kernel_size=1, bias=False)
+        self.conv3 = build_conv_layer(
+            conv_cfg,
+            planes,
+            planes * self.expansion,
+            kernel_size=1,
+            bias=False)
         self.add_module(self.norm3_name, norm3)
 
         self.relu = nn.ReLU(inplace=True)
@@ -236,12 +250,14 @@ def make_res_layer(block,
                    dilation=1,
                    style='pytorch',
                    with_cp=False,
+                   conv_cfg=None,
                    normalize=dict(type='BN'),
                    dcn=None):
     downsample = None
     if stride != 1 or inplanes != planes * block.expansion:
         downsample = nn.Sequential(
-            nn.Conv2d(
+            build_conv_layer(
+                conv_cfg,
                 inplanes,
                 planes * block.expansion,
                 kernel_size=1,
@@ -260,6 +276,7 @@ def make_res_layer(block,
             downsample,
             style=style,
             with_cp=with_cp,
+            conv_cfg=conv_cfg,
             normalize=normalize,
             dcn=dcn))
     inplanes = planes * block.expansion
@@ -272,6 +289,7 @@ def make_res_layer(block,
                 dilation,
                 style=style,
                 with_cp=with_cp,
+                conv_cfg=conv_cfg,
                 normalize=normalize,
                 dcn=dcn))
 
@@ -319,6 +337,7 @@ class ResNet(nn.Module):
                  out_indices=(0, 1, 2, 3),
                  style='pytorch',
                  frozen_stages=-1,
+                 conv_cfg=None,
                  normalize=dict(type='BN', frozen=False),
                  norm_eval=True,
                  dcn=None,
@@ -338,6 +357,7 @@ class ResNet(nn.Module):
         assert max(out_indices) < num_stages
         self.style = style
         self.frozen_stages = frozen_stages
+        self.conv_cfg = conv_cfg
         self.normalize = normalize
         self.with_cp = with_cp
         self.norm_eval = norm_eval
@@ -367,6 +387,7 @@ class ResNet(nn.Module):
                 dilation=dilation,
                 style=self.style,
                 with_cp=with_cp,
+                conv_cfg=conv_cfg,
                 normalize=normalize,
                 dcn=dcn)
             self.inplanes = planes * self.block.expansion
@@ -384,8 +405,14 @@ class ResNet(nn.Module):
         return getattr(self, self.norm1_name)
 
     def _make_stem_layer(self):
-        self.conv1 = nn.Conv2d(
-            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        self.conv1 = build_conv_layer(
+            self.conv_cfg,
+            3,
+            64,
+            kernel_size=7,
+            stride=2,
+            padding=3,
+            bias=False)
         self.norm1_name, norm1 = build_norm_layer(
             self.normalize, 64, postfix=1)
         self.add_module(self.norm1_name, norm1)
diff --git a/mmdet/models/backbones/resnext.py b/mmdet/models/backbones/resnext.py
index 0e83d07..3cdb1cb 100644
--- a/mmdet/models/backbones/resnext.py
+++ b/mmdet/models/backbones/resnext.py
@@ -6,7 +6,7 @@ from mmdet.ops import DeformConv, ModulatedDeformConv
 from .resnet import Bottleneck as _Bottleneck
 from .resnet import ResNet
 from ..registry import BACKBONES
-from ..utils import build_norm_layer
+from ..utils import build_conv_layer, build_norm_layer
 
 
 class Bottleneck(_Bottleneck):
@@ -30,7 +30,8 @@ class Bottleneck(_Bottleneck):
         self.norm3_name, norm3 = build_norm_layer(
             self.normalize, self.planes * self.expansion, postfix=3)
 
-        self.conv1 = nn.Conv2d(
+        self.conv1 = build_conv_layer(
+            self.conv_cfg,
             self.inplanes,
             width,
             kernel_size=1,
@@ -43,7 +44,8 @@ class Bottleneck(_Bottleneck):
             fallback_on_stride = self.dcn.get('fallback_on_stride', False)
             self.with_modulated_dcn = self.dcn.get('modulated', False)
         if not self.with_dcn or fallback_on_stride:
-            self.conv2 = nn.Conv2d(
+            self.conv2 = build_conv_layer(
+                self.conv_cfg,
                 width,
                 width,
                 kernel_size=3,
@@ -53,6 +55,7 @@ class Bottleneck(_Bottleneck):
                 groups=groups,
                 bias=False)
         else:
+            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
             groups = self.dcn.get('groups', 1)
             deformable_groups = self.dcn.get('deformable_groups', 1)
             if not self.with_modulated_dcn:
@@ -79,8 +82,12 @@ class Bottleneck(_Bottleneck):
                 deformable_groups=deformable_groups,
                 bias=False)
         self.add_module(self.norm2_name, norm2)
-        self.conv3 = nn.Conv2d(
-            width, self.planes * self.expansion, kernel_size=1, bias=False)
+        self.conv3 = build_conv_layer(
+            self.conv_cfg,
+            width,
+            self.planes * self.expansion,
+            kernel_size=1,
+            bias=False)
         self.add_module(self.norm3_name, norm3)
 
 
@@ -94,12 +101,14 @@ def make_res_layer(block,
                    base_width=4,
                    style='pytorch',
                    with_cp=False,
+                   conv_cfg=None,
                    normalize=dict(type='BN'),
                    dcn=None):
     downsample = None
     if stride != 1 or inplanes != planes * block.expansion:
         downsample = nn.Sequential(
-            nn.Conv2d(
+            build_conv_layer(
+                conv_cfg,
                 inplanes,
                 planes * block.expansion,
                 kernel_size=1,
@@ -120,6 +129,7 @@ def make_res_layer(block,
             base_width=base_width,
             style=style,
             with_cp=with_cp,
+            conv_cfg=conv_cfg,
             normalize=normalize,
             dcn=dcn))
     inplanes = planes * block.expansion
@@ -134,6 +144,7 @@ def make_res_layer(block,
                 base_width=base_width,
                 style=style,
                 with_cp=with_cp,
+                conv_cfg=conv_cfg,
                 normalize=normalize,
                 dcn=dcn))
 
@@ -196,6 +207,7 @@ class ResNeXt(ResNet):
                 base_width=self.base_width,
                 style=self.style,
                 with_cp=self.with_cp,
+                conv_cfg=self.conv_cfg,
                 normalize=self.normalize,
                 dcn=dcn)
             self.inplanes = planes * self.block.expansion
diff --git a/mmdet/models/bbox_heads/convfc_bbox_head.py b/mmdet/models/bbox_heads/convfc_bbox_head.py
index c424aff..af2a5e3 100644
--- a/mmdet/models/bbox_heads/convfc_bbox_head.py
+++ b/mmdet/models/bbox_heads/convfc_bbox_head.py
@@ -24,6 +24,7 @@ class ConvFCBBoxHead(BBoxHead):
                  num_reg_fcs=0,
                  conv_out_channels=256,
                  fc_out_channels=1024,
+                 conv_cfg=None,
                  normalize=None,
                  *args,
                  **kwargs):
@@ -44,6 +45,7 @@ class ConvFCBBoxHead(BBoxHead):
         self.num_reg_fcs = num_reg_fcs
         self.conv_out_channels = conv_out_channels
         self.fc_out_channels = fc_out_channels
+        self.conv_cfg = conv_cfg
         self.normalize = normalize
         self.with_bias = normalize is None
 
@@ -101,6 +103,7 @@ class ConvFCBBoxHead(BBoxHead):
                         self.conv_out_channels,
                         3,
                         padding=1,
+                        conv_cfg=self.conv_cfg,
                         normalize=self.normalize,
                         bias=self.with_bias))
             last_layer_dim = self.conv_out_channels
diff --git a/mmdet/models/mask_heads/fcn_mask_head.py b/mmdet/models/mask_heads/fcn_mask_head.py
index e1889f0..614497a 100644
--- a/mmdet/models/mask_heads/fcn_mask_head.py
+++ b/mmdet/models/mask_heads/fcn_mask_head.py
@@ -22,6 +22,7 @@ class FCNMaskHead(nn.Module):
                  upsample_ratio=2,
                  num_classes=81,
                  class_agnostic=False,
+                 conv_cfg=None,
                  normalize=None):
         super(FCNMaskHead, self).__init__()
         if upsample_method not in [None, 'deconv', 'nearest', 'bilinear']:
@@ -37,6 +38,7 @@ class FCNMaskHead(nn.Module):
         self.upsample_ratio = upsample_ratio
         self.num_classes = num_classes
         self.class_agnostic = class_agnostic
+        self.conv_cfg = conv_cfg
         self.normalize = normalize
         self.with_bias = normalize is None
 
@@ -51,6 +53,7 @@ class FCNMaskHead(nn.Module):
                     self.conv_out_channels,
                     self.conv_kernel_size,
                     padding=padding,
+                    conv_cfg=conv_cfg,
                     normalize=normalize,
                     bias=self.with_bias))
         upsample_in_channels = (self.conv_out_channels
diff --git a/mmdet/models/mask_heads/fused_semantic_head.py b/mmdet/models/mask_heads/fused_semantic_head.py
index 8795f66..f24adf3 100644
--- a/mmdet/models/mask_heads/fused_semantic_head.py
+++ b/mmdet/models/mask_heads/fused_semantic_head.py
@@ -30,6 +30,7 @@ class FusedSemanticHead(nn.Module):
                  num_classes=183,
                  ignore_label=255,
                  loss_weight=0.2,
+                 conv_cfg=None,
                  normalize=None):
         super(FusedSemanticHead, self).__init__()
         self.num_ins = num_ins
@@ -40,6 +41,7 @@ class FusedSemanticHead(nn.Module):
         self.num_classes = num_classes
         self.ignore_label = ignore_label
         self.loss_weight = loss_weight
+        self.conv_cfg = conv_cfg
         self.normalize = normalize
         self.with_bias = normalize is None
 
@@ -50,6 +52,7 @@ class FusedSemanticHead(nn.Module):
                     self.in_channels,
                     self.in_channels,
                     1,
+                    conv_cfg=self.conv_cfg,
                     normalize=self.normalize,
                     bias=self.with_bias,
                     inplace=False))
@@ -63,12 +66,14 @@ class FusedSemanticHead(nn.Module):
                     conv_out_channels,
                     3,
                     padding=1,
+                    conv_cfg=self.conv_cfg,
                     normalize=self.normalize,
                     bias=self.with_bias))
         self.conv_embedding = ConvModule(
             conv_out_channels,
             conv_out_channels,
             1,
+            conv_cfg=self.conv_cfg,
             normalize=self.normalize,
             bias=self.with_bias)
         self.conv_logits = nn.Conv2d(conv_out_channels, self.num_classes, 1)
diff --git a/mmdet/models/mask_heads/htc_mask_head.py b/mmdet/models/mask_heads/htc_mask_head.py
index 6f1ccf0..21f3130 100644
--- a/mmdet/models/mask_heads/htc_mask_head.py
+++ b/mmdet/models/mask_heads/htc_mask_head.py
@@ -12,6 +12,7 @@ class HTCMaskHead(FCNMaskHead):
             self.conv_out_channels,
             self.conv_out_channels,
             1,
+            conv_cfg=self.conv_cfg,
             normalize=self.normalize,
             bias=self.with_bias)
 
diff --git a/mmdet/models/necks/fpn.py b/mmdet/models/necks/fpn.py
index 90c0f23..3e49fc4 100644
--- a/mmdet/models/necks/fpn.py
+++ b/mmdet/models/necks/fpn.py
@@ -17,6 +17,7 @@ class FPN(nn.Module):
                  end_level=-1,
                  add_extra_convs=False,
                  extra_convs_on_inputs=True,
+                 conv_cfg=None,
                  normalize=None,
                  activation=None):
         super(FPN, self).__init__()
@@ -49,6 +50,7 @@ class FPN(nn.Module):
                 in_channels[i],
                 out_channels,
                 1,
+                conv_cfg=conv_cfg,
                 normalize=normalize,
                 bias=self.with_bias,
                 activation=self.activation,
@@ -58,6 +60,7 @@ class FPN(nn.Module):
                 out_channels,
                 3,
                 padding=1,
+                conv_cfg=conv_cfg,
                 normalize=normalize,
                 bias=self.with_bias,
                 activation=self.activation,
diff --git a/mmdet/models/utils/__init__.py b/mmdet/models/utils/__init__.py
index 194eb9c..c517b07 100644
--- a/mmdet/models/utils/__init__.py
+++ b/mmdet/models/utils/__init__.py
@@ -1,9 +1,11 @@
-from .conv_module import ConvModule
+from .conv_ws import conv_ws_2d, ConvWS2d
+from .conv_module import build_conv_layer, ConvModule
 from .norm import build_norm_layer
 from .weight_init import (xavier_init, normal_init, uniform_init, kaiming_init,
                           bias_init_with_prob)
 
 __all__ = [
-    'ConvModule', 'build_norm_layer', 'xavier_init', 'normal_init',
-    'uniform_init', 'kaiming_init', 'bias_init_with_prob'
+    'conv_ws_2d', 'ConvWS2d', 'build_conv_layer', 'ConvModule',
+    'build_norm_layer', 'xavier_init', 'normal_init', 'uniform_init',
+    'kaiming_init', 'bias_init_with_prob'
 ]
diff --git a/mmdet/models/utils/conv_module.py b/mmdet/models/utils/conv_module.py
index b5651ec..b3bf9c7 100644
--- a/mmdet/models/utils/conv_module.py
+++ b/mmdet/models/utils/conv_module.py
@@ -3,8 +3,43 @@ import warnings
 import torch.nn as nn
 from mmcv.cnn import kaiming_init, constant_init
 
+from .conv_ws import ConvWS2d
 from .norm import build_norm_layer
 
+conv_cfg = {
+    'Conv': nn.Conv2d,
+    'ConvWS': ConvWS2d,
+    # TODO: octave conv
+}
+
+
+def build_conv_layer(cfg, *args, **kwargs):
+    """ Build convolution layer
+
+    Args:
+        cfg (None or dict): cfg should contain:
+            type (str): identify conv layer type.
+            layer args: args needed to instantiate a conv layer.
+
+    Returns:
+        layer (nn.Module): created conv layer
+    """
+    if cfg is None:
+        cfg_ = dict(type='Conv')
+    else:
+        assert isinstance(cfg, dict) and 'type' in cfg
+        cfg_ = cfg.copy()
+
+    layer_type = cfg_.pop('type')
+    if layer_type not in conv_cfg:
+        raise KeyError('Unrecognized norm type {}'.format(layer_type))
+    else:
+        conv_layer = conv_cfg[layer_type]
+
+    layer = conv_layer(*args, **kwargs, **cfg_)
+
+    return layer
+
 
 class ConvModule(nn.Module):
 
@@ -17,11 +52,14 @@ class ConvModule(nn.Module):
                  dilation=1,
                  groups=1,
                  bias=True,
+                 conv_cfg=None,
                  normalize=None,
                  activation='relu',
                  inplace=True,
                  activate_last=True):
         super(ConvModule, self).__init__()
+        assert conv_cfg is None or isinstance(conv_cfg, dict)
+        assert normalize is None or isinstance(normalize, dict)
         self.with_norm = normalize is not None
         self.with_activatation = activation is not None
         self.with_bias = bias
@@ -31,7 +69,8 @@ class ConvModule(nn.Module):
         if self.with_norm and self.with_bias:
             warnings.warn('ConvModule has norm and bias at the same time')
 
-        self.conv = nn.Conv2d(
+        self.conv = build_conv_layer(
+            conv_cfg,
             in_channels,
             out_channels,
             kernel_size,
diff --git a/mmdet/models/utils/conv_ws.py b/mmdet/models/utils/conv_ws.py
new file mode 100644
index 0000000..5ccd735
--- /dev/null
+++ b/mmdet/models/utils/conv_ws.py
@@ -0,0 +1,46 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv_ws_2d(input,
+               weight,
+               bias=None,
+               stride=1,
+               padding=0,
+               dilation=1,
+               groups=1,
+               eps=1e-5):
+    c_in = weight.size(0)
+    weight_flat = weight.view(c_in, -1)
+    mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
+    std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
+    weight = (weight - mean) / (std + eps)
+    return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
+
+
+class ConvWS2d(nn.Conv2d):
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True,
+                 eps=1e-5):
+        super(ConvWS2d, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias)
+        self.eps = eps
+
+    def forward(self, x):
+        return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
+                          self.dilation, self.groups, self.eps)
-- 
GitLab