diff --git a/configs/ms_rcnn/README.md b/configs/ms_rcnn/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3b08ad15e7ae833570729021b9de0bda6569b332
--- /dev/null
+++ b/configs/ms_rcnn/README.md
@@ -0,0 +1,23 @@
+# Mask Scoring R-CNN
+
+## Introduction
+
+```
+@inproceedings{huang2019msrcnn,
+    title={Mask Scoring R-CNN},
+    author={Zhaojin Huang and Lichao Huang and Yongchao Gong and Chang Huang and Xinggang Wang},
+    booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+    year={2019},
+}
+```
+
+## Results and Models
+
+| Backbone      | Style      | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | mask AP | Download |
+|:-------------:|:----------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:-------:|:--------:|
+| R-50-FPN      | caffe      | 1x      | 4.3      | 0.537               | 10.1           | 37.4   | 35.5    | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/ms-rcnn/ms_rcnn_r50_caffe_fpn_1x_20190624-619934b5.pth) |
+| R-50-FPN      | caffe      | 2x      | -        | -                   | -              | 38.2   | 35.9    | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ms-rcnn/ms_rcnn_r50_caffe_fpn_2x_20190525-a07be31e.pth) |
+| R-101-FPN     | caffe      | 1x      | 6.2      | 0.682               |  9.1           | 39.8   | 37.2    | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/ms-rcnn/ms_rcnn_r101_caffe_fpn_1x_20190624-677a5548.pth) |
+| R-101-FPN     | caffe      | 2x      | -        | -                   |  -             | 40.7   | 37.8    | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ms-rcnn/ms_rcnn_r101_caffe_fpn_2x_20190525-4aee1528.pth) |
+| R-X101-64x4d  | pytorch    | 1x      | 10.5     | 1.214               |  6.4           | 42.2   | 39.2    | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ms-rcnn/ms_rcnn_x101_64x4d_fpn_1x_20190525-026b16ae.pth) |
+| R-X101-64x4d  | pytorch    | 2x      | -        | -                   |  -             | 42.2   | 38.9    | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ms-rcnn/ms_rcnn_x101_64x4d_fpn_2x_20190525-c044c25a.pth) |
diff --git a/configs/ms_rcnn/ms_rcnn_r101_caffe_fpn_1x.py b/configs/ms_rcnn/ms_rcnn_r101_caffe_fpn_1x.py
new file mode 100644
index 0000000000000000000000000000000000000000..682d308a937675bc2a03c59982ab853bd7c39b15
--- /dev/null
+++ b/configs/ms_rcnn/ms_rcnn_r101_caffe_fpn_1x.py
@@ -0,0 +1,192 @@
+# model settings
+model = dict(
+    type='MaskScoringRCNN',
+    pretrained='open-mmlab://resnet101_caffe',
+    backbone=dict(
+        type='ResNet',
+        depth=101,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN', requires_grad=False),
+        style='caffe'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
+    mask_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    mask_head=dict(
+        type='FCNMaskHead',
+        num_convs=4,
+        in_channels=256,
+        conv_out_channels=256,
+        num_classes=81,
+        loss_mask=dict(
+            type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)),
+    mask_iou_head=dict(
+        type='MaskIoUHead',
+        num_convs=4,
+        num_fcs=2,
+        roi_feat_size=14,
+        in_channels=256,
+        conv_out_channels=256,
+        fc_out_channels=1024,
+        num_classes=81))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        mask_size=28,
+        pos_weight=-1,
+        mask_thr_binary=0.5,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05,
+        nms=dict(type='nms', iou_thr=0.5),
+        max_per_img=100,
+        mask_thr_binary=0.5))
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/ms_rcnn_r101_caffe_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/ms_rcnn/ms_rcnn_r50_caffe_fpn_1x.py b/configs/ms_rcnn/ms_rcnn_r50_caffe_fpn_1x.py
new file mode 100644
index 0000000000000000000000000000000000000000..16ebde50bfcc25b79447dc894352a8d48638d40b
--- /dev/null
+++ b/configs/ms_rcnn/ms_rcnn_r50_caffe_fpn_1x.py
@@ -0,0 +1,192 @@
+# model settings
+model = dict(
+    type='MaskScoringRCNN',
+    pretrained='open-mmlab://resnet50_caffe',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN', requires_grad=False),
+        style='caffe'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
+    mask_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    mask_head=dict(
+        type='FCNMaskHead',
+        num_convs=4,
+        in_channels=256,
+        conv_out_channels=256,
+        num_classes=81,
+        loss_mask=dict(
+            type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)),
+    mask_iou_head=dict(
+        type='MaskIoUHead',
+        num_convs=4,
+        num_fcs=2,
+        roi_feat_size=14,
+        in_channels=256,
+        conv_out_channels=256,
+        fc_out_channels=1024,
+        num_classes=81))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        mask_size=28,
+        pos_weight=-1,
+        mask_thr_binary=0.5,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05,
+        nms=dict(type='nms', iou_thr=0.5),
+        max_per_img=100,
+        mask_thr_binary=0.5))
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/ms_rcnn_r50_caffe_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/ms_rcnn/ms_rcnn_x101_64x4d_fpn_1x.py b/configs/ms_rcnn/ms_rcnn_x101_64x4d_fpn_1x.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6b720332cf33263295dcfeeae0d85b793e5166d
--- /dev/null
+++ b/configs/ms_rcnn/ms_rcnn_x101_64x4d_fpn_1x.py
@@ -0,0 +1,193 @@
+# model settings
+model = dict(
+    type='MaskScoringRCNN',
+    pretrained='open-mmlab://resnext101_64x4d',
+    backbone=dict(
+        type='ResNeXt',
+        depth=101,
+        groups=64,
+        base_width=4,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
+    mask_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    mask_head=dict(
+        type='FCNMaskHead',
+        num_convs=4,
+        in_channels=256,
+        conv_out_channels=256,
+        num_classes=81,
+        loss_mask=dict(
+            type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)),
+    mask_iou_head=dict(
+        type='MaskIoUHead',
+        num_convs=4,
+        num_fcs=2,
+        roi_feat_size=14,
+        in_channels=256,
+        conv_out_channels=256,
+        fc_out_channels=1024,
+        num_classes=81))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        mask_size=28,
+        pos_weight=-1,
+        mask_thr_binary=0.5,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05,
+        nms=dict(type='nms', iou_thr=0.5),
+        max_per_img=100,
+        mask_thr_binary=0.5))
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/ms_rcnn_x101_64x4d_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index 824833d8823d924c75ebe68c3ffca4c80270a981..2424a8bd06cb0b43b522ed98c90d3655205b58ed 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -9,9 +9,10 @@ from .cascade_rcnn import CascadeRCNN
 from .htc import HybridTaskCascade
 from .retinanet import RetinaNet
 from .fcos import FCOS
+from .mask_scoring_rcnn import MaskScoringRCNN
 
 __all__ = [
     'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
     'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
-    'RetinaNet', 'FCOS'
+    'RetinaNet', 'FCOS', 'MaskScoringRCNN'
 ]
diff --git a/mmdet/models/detectors/mask_scoring_rcnn.py b/mmdet/models/detectors/mask_scoring_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..9153bcd147b5c4f9658fb7d09403b78b8ba4e181
--- /dev/null
+++ b/mmdet/models/detectors/mask_scoring_rcnn.py
@@ -0,0 +1,197 @@
+import torch
+
+from mmdet.core import bbox2roi, build_assigner, build_sampler
+from .two_stage import TwoStageDetector
+from .. import builder
+from ..registry import DETECTORS
+
+
+@DETECTORS.register_module
+class MaskScoringRCNN(TwoStageDetector):
+    """Mask Scoring RCNN.
+
+    https://arxiv.org/abs/1903.00241
+    """
+
+    def __init__(self,
+                 backbone,
+                 rpn_head,
+                 bbox_roi_extractor,
+                 bbox_head,
+                 mask_roi_extractor,
+                 mask_head,
+                 train_cfg,
+                 test_cfg,
+                 neck=None,
+                 shared_head=None,
+                 mask_iou_head=None,
+                 pretrained=None):
+        super(MaskScoringRCNN, self).__init__(
+            backbone=backbone,
+            neck=neck,
+            shared_head=shared_head,
+            rpn_head=rpn_head,
+            bbox_roi_extractor=bbox_roi_extractor,
+            bbox_head=bbox_head,
+            mask_roi_extractor=mask_roi_extractor,
+            mask_head=mask_head,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            pretrained=pretrained)
+
+        self.mask_iou_head = builder.build_head(mask_iou_head)
+        self.mask_iou_head.init_weights()
+
+    # TODO: refactor forward_train in two stage to reduce code redundancy
+    def forward_train(self,
+                      img,
+                      img_meta,
+                      gt_bboxes,
+                      gt_labels,
+                      gt_bboxes_ignore=None,
+                      gt_masks=None,
+                      proposals=None):
+        x = self.extract_feat(img)
+
+        losses = dict()
+
+        # RPN forward and loss
+        if self.with_rpn:
+            rpn_outs = self.rpn_head(x)
+            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
+                                          self.train_cfg.rpn)
+            rpn_losses = self.rpn_head.loss(
+                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+            losses.update(rpn_losses)
+
+            proposal_cfg = self.train_cfg.get('rpn_proposal',
+                                              self.test_cfg.rpn)
+            proposal_inputs = rpn_outs + (img_meta, proposal_cfg)
+            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
+        else:
+            proposal_list = proposals
+
+        # assign gts and sample proposals
+        if self.with_bbox or self.with_mask:
+            bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
+            bbox_sampler = build_sampler(
+                self.train_cfg.rcnn.sampler, context=self)
+            num_imgs = img.size(0)
+            if gt_bboxes_ignore is None:
+                gt_bboxes_ignore = [None for _ in range(num_imgs)]
+            sampling_results = []
+            for i in range(num_imgs):
+                assign_result = bbox_assigner.assign(proposal_list[i],
+                                                     gt_bboxes[i],
+                                                     gt_bboxes_ignore[i],
+                                                     gt_labels[i])
+                sampling_result = bbox_sampler.sample(
+                    assign_result,
+                    proposal_list[i],
+                    gt_bboxes[i],
+                    gt_labels[i],
+                    feats=[lvl_feat[i][None] for lvl_feat in x])
+                sampling_results.append(sampling_result)
+
+        # bbox head forward and loss
+        if self.with_bbox:
+            rois = bbox2roi([res.bboxes for res in sampling_results])
+            # TODO: a more flexible way to decide which feature maps to use
+            bbox_feats = self.bbox_roi_extractor(
+                x[:self.bbox_roi_extractor.num_inputs], rois)
+            if self.with_shared_head:
+                bbox_feats = self.shared_head(bbox_feats)
+            cls_score, bbox_pred = self.bbox_head(bbox_feats)
+
+            bbox_targets = self.bbox_head.get_target(sampling_results,
+                                                     gt_bboxes, gt_labels,
+                                                     self.train_cfg.rcnn)
+            loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
+                                            *bbox_targets)
+            losses.update(loss_bbox)
+
+        # mask head forward and loss
+        if self.with_mask:
+            if not self.share_roi_extractor:
+                pos_rois = bbox2roi(
+                    [res.pos_bboxes for res in sampling_results])
+                mask_feats = self.mask_roi_extractor(
+                    x[:self.mask_roi_extractor.num_inputs], pos_rois)
+                if self.with_shared_head:
+                    mask_feats = self.shared_head(mask_feats)
+            else:
+                pos_inds = []
+                device = bbox_feats.device
+                for res in sampling_results:
+                    pos_inds.append(
+                        torch.ones(
+                            res.pos_bboxes.shape[0],
+                            device=device,
+                            dtype=torch.uint8))
+                    pos_inds.append(
+                        torch.zeros(
+                            res.neg_bboxes.shape[0],
+                            device=device,
+                            dtype=torch.uint8))
+                pos_inds = torch.cat(pos_inds)
+                mask_feats = bbox_feats[pos_inds]
+            mask_pred = self.mask_head(mask_feats)
+
+            mask_targets = self.mask_head.get_target(sampling_results,
+                                                     gt_masks,
+                                                     self.train_cfg.rcnn)
+            pos_labels = torch.cat(
+                [res.pos_gt_labels for res in sampling_results])
+            loss_mask = self.mask_head.loss(mask_pred, mask_targets,
+                                            pos_labels)
+            losses.update(loss_mask)
+
+            # mask iou head forward and loss
+            pos_mask_pred = mask_pred[range(mask_pred.size(0)), pos_labels]
+            mask_iou_pred = self.mask_iou_head(mask_feats, pos_mask_pred)
+            pos_mask_iou_pred = mask_iou_pred[range(mask_iou_pred.size(0)
+                                                    ), pos_labels]
+            mask_iou_targets = self.mask_iou_head.get_target(
+                sampling_results, gt_masks, pos_mask_pred, mask_targets,
+                self.train_cfg.rcnn)
+            loss_mask_iou = self.mask_iou_head.loss(pos_mask_iou_pred,
+                                                    mask_iou_targets)
+            losses.update(loss_mask_iou)
+        return losses
+
+    def simple_test_mask(self,
+                         x,
+                         img_meta,
+                         det_bboxes,
+                         det_labels,
+                         rescale=False):
+        # image shape of the first image in the batch (only one)
+        ori_shape = img_meta[0]['ori_shape']
+        scale_factor = img_meta[0]['scale_factor']
+
+        if det_bboxes.shape[0] == 0:
+            segm_result = [[] for _ in range(self.mask_head.num_classes - 1)]
+            mask_scores = [[] for _ in range(self.mask_head.num_classes - 1)]
+        else:
+            # if det_bboxes is rescaled to the original image size, we need to
+            # rescale it back to the testing scale to obtain RoIs.
+            _bboxes = (
+                det_bboxes[:, :4] * scale_factor if rescale else det_bboxes)
+            mask_rois = bbox2roi([_bboxes])
+            mask_feats = self.mask_roi_extractor(
+                x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois)
+            if self.with_shared_head:
+                mask_feats = self.shared_head(mask_feats)
+            mask_pred = self.mask_head(mask_feats)
+            segm_result = self.mask_head.get_seg_masks(mask_pred, _bboxes,
+                                                       det_labels,
+                                                       self.test_cfg.rcnn,
+                                                       ori_shape, scale_factor,
+                                                       rescale)
+            # get mask scores with mask iou head
+            mask_iou_pred = self.mask_iou_head(
+                mask_feats,
+                mask_pred[range(det_labels.size(0)), det_labels + 1])
+            mask_scores = self.mask_iou_head.get_mask_scores(
+                mask_iou_pred, det_bboxes, det_labels)
+        return segm_result, mask_scores
diff --git a/mmdet/models/losses/__init__.py b/mmdet/models/losses/__init__.py
index 45920dd3c209f1f756f0294bb1465a3faca61dcd..531e5f1b85883230e919c8fd1a156f000251a88e 100644
--- a/mmdet/models/losses/__init__.py
+++ b/mmdet/models/losses/__init__.py
@@ -5,6 +5,7 @@ from .focal_loss import sigmoid_focal_loss, FocalLoss
 from .smooth_l1_loss import smooth_l1_loss, SmoothL1Loss
 from .ghm_loss import GHMC, GHMR
 from .balanced_l1_loss import balanced_l1_loss, BalancedL1Loss
+from .mse_loss import mse_loss, MSELoss
 from .iou_loss import iou_loss, bounded_iou_loss, IoULoss, BoundedIoULoss
 from .utils import reduce_loss, weight_reduce_loss, weighted_loss
 
@@ -12,7 +13,7 @@ __all__ = [
     'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
     'mask_cross_entropy', 'CrossEntropyLoss', 'sigmoid_focal_loss',
     'FocalLoss', 'smooth_l1_loss', 'SmoothL1Loss', 'balanced_l1_loss',
-    'BalancedL1Loss', 'iou_loss', 'bounded_iou_loss', 'IoULoss',
-    'BoundedIoULoss', 'GHMC', 'GHMR', 'reduce_loss', 'weight_reduce_loss',
-    'weighted_loss'
+    'BalancedL1Loss', 'mse_loss', 'MSELoss', 'iou_loss', 'bounded_iou_loss',
+    'IoULoss', 'BoundedIoULoss', 'GHMC', 'GHMR', 'reduce_loss',
+    'weight_reduce_loss', 'weighted_loss'
 ]
diff --git a/mmdet/models/losses/mse_loss.py b/mmdet/models/losses/mse_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..a50f4592017e4413989874ea19342edad94fead7
--- /dev/null
+++ b/mmdet/models/losses/mse_loss.py
@@ -0,0 +1,25 @@
import torch.nn as nn
import torch.nn.functional as F

from .utils import weighted_loss
from ..registry import LOSSES

# Element-wise MSE wrapped by ``weighted_loss`` so that it accepts an
# optional per-element weight, a reduction mode and an averaging factor.
mse_loss = weighted_loss(F.mse_loss)


@LOSSES.register_module
class MSELoss(nn.Module):
    """Mean squared error loss.

    Args:
        reduction (str): Reduction mode passed to ``mse_loss``
            ('none', 'mean' or 'sum').
        loss_weight (float): Global scale applied to the reduced loss.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight=None, avg_factor=None):
        """Return the (optionally weighted) MSE between pred and target."""
        unweighted = mse_loss(
            pred,
            target,
            weight,
            reduction=self.reduction,
            avg_factor=avg_factor)
        return self.loss_weight * unweighted
diff --git a/mmdet/models/mask_heads/__init__.py b/mmdet/models/mask_heads/__init__.py
index 899e4645a190d9912722305b86f9b166ed946d0f..ce1f3e7f3d5c8a6a05dc0a11ccfb410c67f02670 100644
--- a/mmdet/models/mask_heads/__init__.py
+++ b/mmdet/models/mask_heads/__init__.py
@@ -1,5 +1,6 @@
 from .fcn_mask_head import FCNMaskHead
-from .htc_mask_head import HTCMaskHead
 from .fused_semantic_head import FusedSemanticHead
+from .htc_mask_head import HTCMaskHead
+from .maskiou_head import MaskIoUHead
 
-__all__ = ['FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead']
+__all__ = ['FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead', 'MaskIoUHead']
diff --git a/mmdet/models/mask_heads/maskiou_head.py b/mmdet/models/mask_heads/maskiou_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfa1764345d1ef399ee66d9da11760086db0076f
--- /dev/null
+++ b/mmdet/models/mask_heads/maskiou_head.py
@@ -0,0 +1,181 @@
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import kaiming_init, normal_init

from ..builder import build_loss
from ..registry import HEADS


@HEADS.register_module
class MaskIoUHead(nn.Module):
    """Mask IoU Head.

    This head takes RoI mask features concatenated with the predicted mask
    and regresses the IoU between the predicted mask and the corresponding
    gt mask (Mask Scoring R-CNN).
    """

    def __init__(self,
                 num_convs=4,
                 num_fcs=2,
                 roi_feat_size=14,
                 in_channels=256,
                 conv_out_channels=256,
                 fc_out_channels=1024,
                 num_classes=81,
                 loss_iou=dict(type='MSELoss', loss_weight=0.5)):
        super(MaskIoUHead, self).__init__()
        self.in_channels = in_channels
        self.conv_out_channels = conv_out_channels
        self.fc_out_channels = fc_out_channels
        self.num_classes = num_classes

        self.convs = nn.ModuleList()
        for i in range(num_convs):
            if i == 0:
                # input is the concatenation of mask feature and the pooled
                # mask prediction, hence one extra channel
                in_channels = self.in_channels + 1
            else:
                in_channels = self.conv_out_channels
            # the last conv downsamples by 2 so that the flattened size
            # matches (roi_feat_size // 2)**2 used for the first fc below
            stride = 2 if i == num_convs - 1 else 1
            self.convs.append(
                nn.Conv2d(
                    in_channels,
                    self.conv_out_channels,
                    3,
                    stride=stride,
                    padding=1))

        self.fcs = nn.ModuleList()
        for i in range(num_fcs):
            in_channels = self.conv_out_channels * (
                roi_feat_size // 2)**2 if i == 0 else self.fc_out_channels
            self.fcs.append(nn.Linear(in_channels, self.fc_out_channels))

        # one IoU prediction per class (index 0 is the background class)
        self.fc_mask_iou = nn.Linear(self.fc_out_channels, self.num_classes)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(2, 2)
        self.loss_iou = build_loss(loss_iou)

    def init_weights(self):
        """Initialize convs (kaiming), fcs (uniform kaiming) and the final
        prediction layer (normal, std=0.01)."""
        for conv in self.convs:
            kaiming_init(conv)
        for fc in self.fcs:
            kaiming_init(
                fc,
                a=1,
                mode='fan_in',
                nonlinearity='leaky_relu',
                distribution='uniform')
        normal_init(self.fc_mask_iou, std=0.01)

    def forward(self, mask_feat, mask_pred):
        """Predict per-class mask IoU.

        Args:
            mask_feat (Tensor): RoI mask features, shape (n, C, 2h, 2w)
                relative to the pooled prediction.
            mask_pred (Tensor): Mask logits of shape (n, 2h, 2w); they are
                passed through sigmoid and max-pooled to (n, 1, h, w)
                before concatenation.

        Returns:
            Tensor: Mask IoU predictions of shape (n, num_classes).
        """
        mask_pred = mask_pred.sigmoid()
        mask_pred_pooled = self.max_pool(mask_pred.unsqueeze(1))

        x = torch.cat((mask_feat, mask_pred_pooled), 1)

        for conv in self.convs:
            x = self.relu(conv(x))
        x = x.view(x.size(0), -1)
        for fc in self.fcs:
            x = self.relu(fc(x))
        mask_iou = self.fc_mask_iou(x)
        return mask_iou

    def loss(self, mask_iou_pred, mask_iou_targets):
        """Compute the mask IoU loss over positive (non-zero IoU) samples."""
        pos_inds = mask_iou_targets > 0
        if pos_inds.sum() > 0:
            loss_mask_iou = self.loss_iou(mask_iou_pred[pos_inds],
                                          mask_iou_targets[pos_inds])
        else:
            # No positive sample: return a scalar zero that stays connected
            # to the graph. The previous `mask_iou_pred * 0` produced a
            # (num_pos, num_classes) tensor instead of a scalar loss.
            loss_mask_iou = mask_iou_pred.sum() * 0
        return dict(loss_mask_iou=loss_mask_iou)

    def get_target(self, sampling_results, gt_masks, mask_pred, mask_targets,
                   rcnn_train_cfg):
        """Compute target of mask IoU.

        Mask IoU target is the IoU of the predicted mask (inside a bbox) and
        the gt mask of corresponding gt mask (the whole instance).
        The intersection area is computed inside the bbox, and the gt mask area
        is computed with two steps, firstly we compute the gt area inside the
        bbox, then divide it by the area ratio of gt area inside the bbox and
        the gt area of the whole instance.

        Args:
            sampling_results (list[:obj:`SamplingResult`]): sampling results.
            gt_masks (list[ndarray]): Gt masks (the whole instance) of each
                image, binary maps with the same shape of the input image.
            mask_pred (Tensor): Predicted masks of each positive proposal,
                shape (num_pos, h, w).
            mask_targets (Tensor): Gt mask of each positive proposal,
                binary map of the shape (num_pos, h, w).
            rcnn_train_cfg (dict): Training config for R-CNN part.

        Returns:
            Tensor: mask iou target (length == num positive).
        """
        pos_proposals = [res.pos_bboxes for res in sampling_results]
        pos_assigned_gt_inds = [
            res.pos_assigned_gt_inds for res in sampling_results
        ]

        # compute the area ratio of gt areas inside the proposals and
        # the whole instance
        area_ratios = map(self._get_area_ratio, pos_proposals,
                          pos_assigned_gt_inds, gt_masks)
        area_ratios = torch.cat(list(area_ratios))
        assert mask_targets.size(0) == area_ratios.size(0)

        mask_pred = (mask_pred > rcnn_train_cfg.mask_thr_binary).float()
        mask_pred_areas = mask_pred.sum((-1, -2))

        # mask_pred and mask_targets are binary maps
        overlap_areas = (mask_pred * mask_targets).sum((-1, -2))

        # compute the mask area of the whole instance
        gt_full_areas = mask_targets.sum((-1, -2)) / (area_ratios + 1e-7)

        # eps guards against 0/0 NaN when both the predicted and gt masks
        # are empty (union area of zero)
        mask_iou_targets = overlap_areas / (
            mask_pred_areas + gt_full_areas - overlap_areas + 1e-7)
        return mask_iou_targets

    def _get_area_ratio(self, pos_proposals, pos_assigned_gt_inds, gt_masks):
        """Compute area ratio of the gt mask inside the proposal and the gt
        mask of the corresponding instance"""
        num_pos = pos_proposals.size(0)
        if num_pos > 0:
            area_ratios = []
            proposals_np = pos_proposals.cpu().numpy()
            pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy()
            # compute mask areas of gt instances (batch processing for speedup)
            gt_instance_mask_area = gt_masks.sum((-1, -2))
            for i in range(num_pos):
                gt_mask = gt_masks[pos_assigned_gt_inds[i]]

                # crop the gt mask inside the proposal
                x1, y1, x2, y2 = proposals_np[i, :].astype(np.int32)
                gt_mask_in_proposal = gt_mask[y1:y2 + 1, x1:x2 + 1]

                ratio = gt_mask_in_proposal.sum() / (
                    gt_instance_mask_area[pos_assigned_gt_inds[i]] + 1e-7)
                area_ratios.append(ratio)
            area_ratios = torch.from_numpy(np.stack(area_ratios)).float().to(
                pos_proposals.device)
        else:
            area_ratios = pos_proposals.new_zeros((0, ))
        return area_ratios

    def get_mask_scores(self, mask_iou_pred, det_bboxes, det_labels):
        """Get the mask scores.

        mask_score = bbox_score * mask_iou

        Returns:
            list[ndarray]: per-class (background excluded) mask scores,
                matching the layout of detection results.
        """
        inds = range(det_labels.size(0))
        # det_labels are 0-based foreground labels; +1 maps to the class
        # channel of mask_iou_pred (channel 0 is background)
        mask_scores = mask_iou_pred[inds, det_labels +
                                    1] * det_bboxes[inds, -1]
        mask_scores = mask_scores.cpu().numpy()
        det_labels = det_labels.cpu().numpy()
        return [
            mask_scores[det_labels == i] for i in range(self.num_classes - 1)
        ]