diff --git a/configs/empirical_attention/README.md b/configs/empirical_attention/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0c8650a616a8a7d62aa473ece2960a5f22d36cb0
--- /dev/null
+++ b/configs/empirical_attention/README.md
@@ -0,0 +1,23 @@
+# An Empirical Study of Spatial Attention Mechanisms in Deep Networks
+
+## Introduction
+
+```
+@article{zhu2019empirical,
+  title={An Empirical Study of Spatial Attention Mechanisms in Deep Networks},
+  author={Zhu, Xizhou and Cheng, Dazhi and Zhang, Zheng and Lin, Stephen and Dai, Jifeng},
+  journal={arXiv preprint arXiv:1904.05873},
+  year={2019}
+}
+```
+
+
+## Results and Models
+
+| Backbone  | Attention Component | DCN  | Lr schd | box AP | Download |
+|:---------:|:-------------------:|:----:|:-------:|:------:|:--------:|
+| R-50      | 1111                | N    | 1x      | 38.6   |     -    |
+| R-50      | 0010                | N    | 1x      | 38.2   |     -    |
+| R-50      | 1111                | Y    | 1x      | 41.0   |     -    |
+| R-50      | 0010                | Y    | 1x      | 40.8   |     -    |
+
diff --git a/configs/empirical_attention/faster_rcnn_r50_fpn_attention_0010_1x.py b/configs/empirical_attention/faster_rcnn_r50_fpn_attention_0010_1x.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0bc83101943a8fa7c2fb223dcc226380f6dd34e
--- /dev/null
+++ b/configs/empirical_attention/faster_rcnn_r50_fpn_attention_0010_1x.py
@@ -0,0 +1,171 @@
+# model settings
+model = dict(
+    type='FasterRCNN',
+    pretrained='modelzoo://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch',
+        gen_attention=dict(
+            spatial_range=-1, num_heads=8, attention_type='0010', kv_stride=2),
+        stage_with_gen_attention=[[], [], [0, 1, 2, 3, 4, 5], [0, 1, 2]],
+    ),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
+    # soft-nms is also supported for rcnn testing
+    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
+)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/faster_rcnn_r50_fpn_attention_0010_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/empirical_attention/faster_rcnn_r50_fpn_attention_0010_dcn_1x.py b/configs/empirical_attention/faster_rcnn_r50_fpn_attention_0010_dcn_1x.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a6f586c475ba168ba96e27e11030dd14a0fe46d
--- /dev/null
+++ b/configs/empirical_attention/faster_rcnn_r50_fpn_attention_0010_dcn_1x.py
@@ -0,0 +1,174 @@
+# model settings
+model = dict(
+    type='FasterRCNN',
+    pretrained='modelzoo://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch',
+        gen_attention=dict(
+            spatial_range=-1, num_heads=8, attention_type='0010', kv_stride=2),
+        stage_with_gen_attention=[[], [], [0, 1, 2, 3, 4, 5], [0, 1, 2]],
+        dcn=dict(
+            modulated=False, deformable_groups=1, fallback_on_stride=False),
+        stage_with_dcn=(False, True, True, True),
+    ),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
+    # soft-nms is also supported for rcnn testing
+    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
+)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/faster_rcnn_r50_fpn_attention_0010_dcn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/empirical_attention/faster_rcnn_r50_fpn_attention_1111_1x.py b/configs/empirical_attention/faster_rcnn_r50_fpn_attention_1111_1x.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfde7ba6c8396aecb54e0e1d3746ec82a9deb8a2
--- /dev/null
+++ b/configs/empirical_attention/faster_rcnn_r50_fpn_attention_1111_1x.py
@@ -0,0 +1,171 @@
+# model settings
+model = dict(
+    type='FasterRCNN',
+    pretrained='modelzoo://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch',
+        gen_attention=dict(
+            spatial_range=-1, num_heads=8, attention_type='1111', kv_stride=2),
+        stage_with_gen_attention=[[], [], [0, 1, 2, 3, 4, 5], [0, 1, 2]],
+    ),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
+    # soft-nms is also supported for rcnn testing
+    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
+)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/faster_rcnn_r50_fpn_attention_1111_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/empirical_attention/faster_rcnn_r50_fpn_attention_1111_dcn_1x.py b/configs/empirical_attention/faster_rcnn_r50_fpn_attention_1111_dcn_1x.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ffd518aeb191877b93767f159732ed99f4fdf2e
--- /dev/null
+++ b/configs/empirical_attention/faster_rcnn_r50_fpn_attention_1111_dcn_1x.py
@@ -0,0 +1,174 @@
+# model settings
+model = dict(
+    type='FasterRCNN',
+    pretrained='modelzoo://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch',
+        gen_attention=dict(
+            spatial_range=-1, num_heads=8, attention_type='1111', kv_stride=2),
+        stage_with_gen_attention=[[], [], [0, 1, 2, 3, 4, 5], [0, 1, 2]],
+        dcn=dict(
+            modulated=False, deformable_groups=1, fallback_on_stride=False),
+        stage_with_dcn=(False, True, True, True),
+    ),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
+    # soft-nms is also supported for rcnn testing
+    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
+)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/faster_rcnn_r50_fpn_attention_1111_dcn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py
index c018ed634bf6ed04d1a1adea72a22c143adca218..3a8a6131a7b02e52337ca972efb82a3d96b53b5b 100644
--- a/mmdet/models/backbones/resnet.py
+++ b/mmdet/models/backbones/resnet.py
@@ -8,6 +8,8 @@ from mmcv.cnn import constant_init, kaiming_init
 from mmcv.runner import load_checkpoint
 
 from mmdet.ops import DeformConv, ModulatedDeformConv, ContextBlock
+from mmdet.models.plugins import GeneralizedAttention
+
 from ..registry import BACKBONES
 from ..utils import build_conv_layer, build_norm_layer
 
@@ -26,9 +28,11 @@ class BasicBlock(nn.Module):
                  conv_cfg=None,
                  norm_cfg=dict(type='BN'),
                  dcn=None,
-                 gcb=None):
+                 gcb=None,
+                 gen_attention=None):
         super(BasicBlock, self).__init__()
         assert dcn is None, "Not implemented yet."
+        assert gen_attention is None, "Not implemented yet."
         assert gcb is None, "Not implemented yet."
 
         self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
@@ -95,7 +99,8 @@ class Bottleneck(nn.Module):
                  conv_cfg=None,
                  norm_cfg=dict(type='BN'),
                  dcn=None,
-                 gcb=None):
+                 gcb=None,
+                 gen_attention=None):
         """Bottleneck block for ResNet.
         If style is "pytorch", the stride-two layer is the 3x3 conv layer,
         if it is "caffe", the stride-two layer is the first 1x1 conv layer.
@@ -104,6 +109,8 @@ class Bottleneck(nn.Module):
         assert style in ['pytorch', 'caffe']
         assert dcn is None or isinstance(dcn, dict)
         assert gcb is None or isinstance(gcb, dict)
+        assert gen_attention is None or isinstance(gen_attention, dict)
+
         self.inplanes = inplanes
         self.planes = planes
         self.stride = stride
@@ -116,6 +123,9 @@ class Bottleneck(nn.Module):
         self.with_dcn = dcn is not None
         self.gcb = gcb
         self.with_gcb = gcb is not None
+        self.gen_attention = gen_attention
+        self.with_gen_attention = gen_attention is not None
+
         if self.style == 'pytorch':
             self.conv1_stride = 1
             self.conv2_stride = stride
@@ -187,12 +197,15 @@ class Bottleneck(nn.Module):
 
         self.relu = nn.ReLU(inplace=True)
         self.downsample = downsample
+
         if self.with_gcb:
             gcb_inplanes = planes * self.expansion
-            self.context_block = ContextBlock(
-                inplanes=gcb_inplanes,
-                **gcb
-            )
+            self.context_block = ContextBlock(inplanes=gcb_inplanes, **gcb)
+
+        # gen_attention
+        if self.with_gen_attention:
+            self.gen_attention_block = GeneralizedAttention(
+                planes, **gen_attention)
 
     @property
     def norm1(self):
@@ -228,6 +241,9 @@ class Bottleneck(nn.Module):
             out = self.norm2(out)
             out = self.relu(out)
 
+            if self.with_gen_attention:
+                out = self.gen_attention_block(out)
+
             out = self.conv3(out)
             out = self.norm3(out)
 
@@ -262,7 +278,9 @@ def make_res_layer(block,
                    conv_cfg=None,
                    norm_cfg=dict(type='BN'),
                    dcn=None,
-                   gcb=None):
+                   gcb=None,
+                   gen_attention=None,
+                   gen_attention_blocks=[]):
     downsample = None
     if stride != 1 or inplanes != planes * block.expansion:
         downsample = nn.Sequential(
@@ -289,7 +307,9 @@ def make_res_layer(block,
             conv_cfg=conv_cfg,
             norm_cfg=norm_cfg,
             dcn=dcn,
-            gcb=gcb))
+            gcb=gcb,
+            gen_attention=gen_attention if
+            (0 in gen_attention_blocks) else None))
     inplanes = planes * block.expansion
     for i in range(1, blocks):
         layers.append(
@@ -303,7 +323,9 @@ def make_res_layer(block,
                 conv_cfg=conv_cfg,
                 norm_cfg=norm_cfg,
                 dcn=dcn,
-                gcb=gcb))
+                gcb=gcb,
+                gen_attention=gen_attention if
+                (i in gen_attention_blocks) else None))
 
     return nn.Sequential(*layers)
 
@@ -356,6 +378,8 @@ class ResNet(nn.Module):
                  stage_with_dcn=(False, False, False, False),
                  gcb=None,
                  stage_with_gcb=(False, False, False, False),
+                 gen_attention=None,
+                 stage_with_gen_attention=((), (), (), ()),
                  with_cp=False,
                  zero_init_residual=True):
         super(ResNet, self).__init__()
@@ -379,6 +403,7 @@ class ResNet(nn.Module):
         self.stage_with_dcn = stage_with_dcn
         if dcn is not None:
             assert len(stage_with_dcn) == num_stages
+        self.gen_attention = gen_attention
         self.gcb = gcb
         self.stage_with_gcb = stage_with_gcb
         if gcb is not None:
@@ -409,7 +434,9 @@ class ResNet(nn.Module):
                 conv_cfg=conv_cfg,
                 norm_cfg=norm_cfg,
                 dcn=dcn,
-                gcb=gcb)
+                gcb=gcb,
+                gen_attention=gen_attention,
+                gen_attention_blocks=stage_with_gen_attention[i])
             self.inplanes = planes * self.block.expansion
             layer_name = 'layer{}'.format(i + 1)
             self.add_module(layer_name, res_layer)
diff --git a/mmdet/models/backbones/resnext.py b/mmdet/models/backbones/resnext.py
index ced5c53c693010f9fd6cea6a114bc44f04912273..f68c0e7e10e51e6a44eaf0897d16b93e6f8fcbfc 100644
--- a/mmdet/models/backbones/resnext.py
+++ b/mmdet/models/backbones/resnext.py
@@ -11,7 +11,7 @@ from ..utils import build_conv_layer, build_norm_layer
 
 class Bottleneck(_Bottleneck):
 
-    def __init__(self, *args, groups=1, base_width=4, **kwargs):
+    def __init__(self, groups=1, base_width=4, *args, **kwargs):
         """Bottleneck block for ResNeXt.
         If style is "pytorch", the stride-two layer is the 3x3 conv layer,
         if it is "caffe", the stride-two layer is the first 1x1 conv layer.
diff --git a/mmdet/models/plugins/__init__.py b/mmdet/models/plugins/__init__.py
index 87744df7db561f07794a397c176817908bb58102..2a771b906bcfe71e9b6f53ecf7edb7f6621a0f73 100644
--- a/mmdet/models/plugins/__init__.py
+++ b/mmdet/models/plugins/__init__.py
@@ -1,3 +1,4 @@
 from .non_local import NonLocal2D
+from .generalized_attention import GeneralizedAttention
 
-__all__ = ['NonLocal2D']
+__all__ = ['NonLocal2D', 'GeneralizedAttention']
diff --git a/mmdet/models/plugins/generalized_attention.py b/mmdet/models/plugins/generalized_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..7786837ba9238fc4871b69d70351c71ec92c7312
--- /dev/null
+++ b/mmdet/models/plugins/generalized_attention.py
@@ -0,0 +1,384 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import math
+import numpy as np
+from mmcv.cnn import kaiming_init
+
+
+class GeneralizedAttention(nn.Module):
+    """GeneralizedAttention module.
+
+    See 'An Empirical Study of Spatial Attention Mechanisms in Deep Networks'
+    (https://arxiv.org/abs/1711.07971) for details.
+
+    Args:
+        in_dim (int): Channels of the input feature map.
+        spatial_range (int): The spatial range.
+            -1 indicates no spatial range constraint.
+        num_heads (int): The head number of empirical_attention module.
+        position_embedding_dim (int): The position embedding dimension.
+        position_magnitude (int): A multiplier acting on coord difference.
+        kv_stride (int): The feature stride acting on key/value feature map.
+        q_stride (int): The feature stride acting on query feature map.
+        attention_type (str): A binary indicator string for indicating which
+            items in generalized empirical_attention module are used.
+            '1000' indicates 'query and key content' (appr - appr) item,
+            '0100' indicates 'query content and relative position'
+              (appr - position) item,
+            '0010' indicates 'key content only' (bias - appr) item,
+            '0001' indicates 'relative position only' (bias - position) item.
+    """
+
+    def __init__(self,
+                 in_dim,
+                 spatial_range=-1,
+                 num_heads=9,
+                 position_embedding_dim=-1,
+                 position_magnitude=1,
+                 kv_stride=2,
+                 q_stride=1,
+                 attention_type='1111'):
+
+        super(GeneralizedAttention, self).__init__()
+
+        # hard range means local range for non-local operation
+        self.position_embedding_dim = (
+            position_embedding_dim if position_embedding_dim > 0 else in_dim)
+
+        self.position_magnitude = position_magnitude
+        self.num_heads = num_heads
+        self.channel_in = in_dim
+        self.spatial_range = spatial_range
+        self.kv_stride = kv_stride
+        self.q_stride = q_stride
+        self.attention_type = [bool(int(_)) for _ in attention_type]
+        self.qk_embed_dim = in_dim // num_heads
+        out_c = self.qk_embed_dim * num_heads
+
+        if self.attention_type[0] or self.attention_type[1]:
+            self.query_conv = nn.Conv2d(
+                in_channels=in_dim,
+                out_channels=out_c,
+                kernel_size=1,
+                bias=False)
+            self.query_conv.kaiming_init = True
+
+        if self.attention_type[0] or self.attention_type[2]:
+            self.key_conv = nn.Conv2d(
+                in_channels=in_dim,
+                out_channels=out_c,
+                kernel_size=1,
+                bias=False)
+            self.key_conv.kaiming_init = True
+
+        self.v_dim = in_dim // num_heads
+        self.value_conv = nn.Conv2d(
+            in_channels=in_dim,
+            out_channels=self.v_dim * num_heads,
+            kernel_size=1,
+            bias=False)
+        self.value_conv.kaiming_init = True
+
+        if self.attention_type[1] or self.attention_type[3]:
+            self.appr_geom_fc_x = nn.Linear(
+                self.position_embedding_dim // 2, out_c, bias=False)
+            self.appr_geom_fc_x.kaiming_init = True
+
+            self.appr_geom_fc_y = nn.Linear(
+                self.position_embedding_dim // 2, out_c, bias=False)
+            self.appr_geom_fc_y.kaiming_init = True
+
+        if self.attention_type[2]:
+            stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
+            appr_bias_value = -2 * stdv * torch.rand(out_c) + stdv
+            self.appr_bias = nn.Parameter(appr_bias_value)
+
+        if self.attention_type[3]:
+            stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
+            geom_bias_value = -2 * stdv * torch.rand(out_c) + stdv
+            self.geom_bias = nn.Parameter(geom_bias_value)
+
+        self.proj_conv = nn.Conv2d(
+            in_channels=self.v_dim * num_heads,
+            out_channels=in_dim,
+            kernel_size=1,
+            bias=True)
+        self.proj_conv.kaiming_init = True
+        self.gamma = nn.Parameter(torch.zeros(1))
+
+        if self.spatial_range >= 0:
+            # only works when non local is after 3*3 conv
+            if in_dim == 256:
+                max_len = 84
+            elif in_dim == 512:
+                max_len = 42
+
+            max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
+            local_constraint_map = np.ones(
+                (max_len, max_len, max_len_kv, max_len_kv), dtype=np.int)
+            for iy in range(max_len):
+                for ix in range(max_len):
+                    local_constraint_map[iy, ix,
+                                         max((iy - self.spatial_range) //
+                                             self.kv_stride, 0):min(
+                                                 (iy + self.spatial_range +
+                                                  1) // self.kv_stride +
+                                                 1, max_len),
+                                         max((ix - self.spatial_range) //
+                                             self.kv_stride, 0):min(
+                                                 (ix + self.spatial_range +
+                                                  1) // self.kv_stride +
+                                                 1, max_len)] = 0
+
+            self.local_constraint_map = nn.Parameter(
+                torch.from_numpy(local_constraint_map).byte(),
+                requires_grad=False)
+
+        if self.q_stride > 1:
+            self.q_downsample = nn.AvgPool2d(
+                kernel_size=1, stride=self.q_stride)
+        else:
+            self.q_downsample = None
+
+        if self.kv_stride > 1:
+            self.kv_downsample = nn.AvgPool2d(
+                kernel_size=1, stride=self.kv_stride)
+        else:
+            self.kv_downsample = None
+
+        self.init_weights()
+
+    def get_position_embedding(self,
+                               h,
+                               w,
+                               h_kv,
+                               w_kv,
+                               q_stride,
+                               kv_stride,
+                               device,
+                               feat_dim,
+                               wave_length=1000):
+        h_idxs = torch.linspace(0, h - 1, h).cuda(device)
+        h_idxs = h_idxs.view((h, 1)) * q_stride
+
+        w_idxs = torch.linspace(0, w - 1, w).cuda(device)
+        w_idxs = w_idxs.view((w, 1)) * q_stride
+
+        h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).cuda(device)
+        h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride
+
+        w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).cuda(device)
+        w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride
+
+        # (h, h_kv, 1)
+        h_diff = h_idxs.unsqueeze(1) - h_kv_idxs.unsqueeze(0)
+        h_diff *= self.position_magnitude
+
+        # (w, w_kv, 1)
+        w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
+        w_diff *= self.position_magnitude
+
+        feat_range = torch.arange(0, feat_dim / 4).cuda(device)
+
+        dim_mat = torch.Tensor([wave_length]).cuda(device)
+        dim_mat = dim_mat**((4. / feat_dim) * feat_range)
+        dim_mat = dim_mat.view((1, 1, -1))
+
+        embedding_x = torch.cat(
+            ((w_diff / dim_mat).sin(), (w_diff / dim_mat).cos()), dim=2)
+
+        embedding_y = torch.cat(
+            ((h_diff / dim_mat).sin(), (h_diff / dim_mat).cos()), dim=2)
+
+        return embedding_x, embedding_y
+
+    def forward(self, x_input):
+        num_heads = self.num_heads
+
+        # use empirical_attention
+        if self.q_downsample is not None:
+            x_q = self.q_downsample(x_input)
+        else:
+            x_q = x_input
+        n, _, h, w = x_q.shape
+
+        if self.kv_downsample is not None:
+            x_kv = self.kv_downsample(x_input)
+        else:
+            x_kv = x_input
+        _, _, h_kv, w_kv = x_kv.shape
+
+        if self.attention_type[0] or self.attention_type[1]:
+            proj_query = self.query_conv(x_q).view(
+                (n, num_heads, self.qk_embed_dim, h * w))
+            proj_query = proj_query.permute(0, 1, 3, 2)
+
+        if self.attention_type[0] or self.attention_type[2]:
+            proj_key = self.key_conv(x_kv).view(
+                (n, num_heads, self.qk_embed_dim, h_kv * w_kv))
+
+        if self.attention_type[1] or self.attention_type[3]:
+            position_embed_x, position_embed_y = self.get_position_embedding(
+                h, w, h_kv, w_kv, self.q_stride, self.kv_stride,
+                x_input.device, self.position_embedding_dim)
+            # (n, num_heads, w, w_kv, dim)
+            position_feat_x = self.appr_geom_fc_x(position_embed_x).\
+                view(1, w, w_kv, num_heads, self.qk_embed_dim).\
+                permute(0, 3, 1, 2, 4).\
+                repeat(n, 1, 1, 1, 1)
+
+            # (n, num_heads, h, h_kv, dim)
+            position_feat_y = self.appr_geom_fc_y(position_embed_y).\
+                view(1, h, h_kv, num_heads, self.qk_embed_dim).\
+                permute(0, 3, 1, 2, 4).\
+                repeat(n, 1, 1, 1, 1)
+
+            position_feat_x /= math.sqrt(2)
+            position_feat_y /= math.sqrt(2)
+
+        # accelerate for saliency only
+        if (np.sum(self.attention_type) == 1) and self.attention_type[2]:
+            appr_bias = self.appr_bias.\
+                view(1, num_heads, 1, self.qk_embed_dim).\
+                repeat(n, 1, 1, 1)
+
+            energy = torch.matmul(appr_bias, proj_key).\
+                view(n, num_heads, 1, h_kv * w_kv)
+
+            h = 1
+            w = 1
+        else:
+            # (n, num_heads, h*w, h_kv*w_kv), query before key, 540mb for
+            if not self.attention_type[0]:
+                energy = torch.zeros(
+                    n,
+                    num_heads,
+                    h,
+                    w,
+                    h_kv,
+                    w_kv,
+                    dtype=x_input.dtype,
+                    device=x_input.device)
+
+            # attention_type[0]: appr - appr
+            # attention_type[1]: appr - position
+            # attention_type[2]: bias - appr
+            # attention_type[3]: bias - position
+            if self.attention_type[0] or self.attention_type[2]:
+                if self.attention_type[0] and self.attention_type[2]:
+                    appr_bias = self.appr_bias.\
+                        view(1, num_heads, 1, self.qk_embed_dim)
+                    energy = torch.matmul(proj_query + appr_bias, proj_key).\
+                        view(n, num_heads, h, w, h_kv, w_kv)
+
+                elif self.attention_type[0]:
+                    energy = torch.matmul(proj_query, proj_key).\
+                        view(n, num_heads, h, w, h_kv, w_kv)
+
+                elif self.attention_type[2]:
+                    appr_bias = self.appr_bias.\
+                        view(1, num_heads, 1, self.qk_embed_dim).\
+                        repeat(n, 1, 1, 1)
+
+                    energy += torch.matmul(appr_bias, proj_key).\
+                        view(n, num_heads, 1, 1, h_kv, w_kv)
+
+            if self.attention_type[1] or self.attention_type[3]:
+                if self.attention_type[1] and self.attention_type[3]:
+                    geom_bias = self.geom_bias.\
+                        view(1, num_heads, 1, self.qk_embed_dim)
+
+                    proj_query_reshape = (proj_query + geom_bias).\
+                        view(n, num_heads, h, w, self.qk_embed_dim)
+
+                    energy_x = torch.matmul(
+                        proj_query_reshape.permute(0, 1, 3, 2, 4),
+                        position_feat_x.permute(0, 1, 2, 4, 3))
+                    energy_x = energy_x.\
+                        permute(0, 1, 3, 2, 4).unsqueeze(4)
+
+                    energy_y = torch.matmul(
+                        proj_query_reshape,
+                        position_feat_y.permute(0, 1, 2, 4, 3))
+                    energy_y = energy_y.unsqueeze(5)
+
+                    energy += energy_x + energy_y
+
+                elif self.attention_type[1]:
+                    proj_query_reshape = proj_query.\
+                        view(n, num_heads, h, w, self.qk_embed_dim)
+                    proj_query_reshape = proj_query_reshape.\
+                        permute(0, 1, 3, 2, 4)
+                    position_feat_x_reshape = position_feat_x.\
+                        permute(0, 1, 2, 4, 3)
+                    position_feat_y_reshape = position_feat_y.\
+                        permute(0, 1, 2, 4, 3)
+
+                    energy_x = torch.matmul(proj_query_reshape,
+                                            position_feat_x_reshape)
+                    energy_x = energy_x.permute(0, 1, 3, 2, 4).unsqueeze(4)
+
+                    energy_y = torch.matmul(proj_query_reshape,
+                                            position_feat_y_reshape)
+                    energy_y = energy_y.unsqueeze(5)
+
+                    energy += energy_x + energy_y
+
+                elif self.attention_type[3]:
+                    geom_bias = self.geom_bias.\
+                        view(1, num_heads, self.qk_embed_dim, 1).\
+                        repeat(n, 1, 1, 1)
+
+                    position_feat_x_reshape = position_feat_x.\
+                        view(n, num_heads, w*w_kv, self.qk_embed_dim)
+
+                    position_feat_y_reshape = position_feat_y.\
+                        view(n, num_heads, h * h_kv, self.qk_embed_dim)
+
+                    energy_x = torch.matmul(position_feat_x_reshape, geom_bias)
+                    energy_x = energy_x.view(n, num_heads, 1, w, 1, w_kv)
+
+                    energy_y = torch.matmul(position_feat_y_reshape, geom_bias)
+                    energy_y = energy_y.view(n, num_heads, h, 1, h_kv, 1)
+
+                    energy += energy_x + energy_y
+
+            energy = energy.view(n, num_heads, h * w, h_kv * w_kv)
+
+        if self.spatial_range >= 0:
+            cur_local_constraint_map = \
+                self.local_constraint_map[:h, :w, :h_kv, :w_kv].\
+                contiguous().\
+                view(1, 1, h*w, h_kv*w_kv)
+
+            energy = energy.masked_fill_(cur_local_constraint_map,
+                                         float('-inf'))
+
+        attention = F.softmax(energy, 3)
+
+        proj_value = self.value_conv(x_kv)
+        proj_value_reshape = proj_value.\
+            view((n, num_heads, self.v_dim, h_kv * w_kv)).\
+            permute(0, 1, 3, 2)
+
+        out = torch.matmul(attention, proj_value_reshape).\
+            permute(0, 1, 3, 2).\
+            contiguous().\
+            view(n, self.v_dim * self.num_heads, h, w)
+
+        out = self.proj_conv(out)
+        out = self.gamma * out + x_input
+        return out
+
+    def init_weights(self):
+        for m in self.modules():
+            if hasattr(m, 'kaiming_init') and m.kaiming_init:
+                kaiming_init(
+                    m,
+                    mode='fan_in',
+                    nonlinearity='leaky_relu',
+                    bias=0,
+                    distribution='uniform',
+                    a=1)