diff --git a/configs/double_heads/dh_faster_rcnn_r50_fpn_1x.py b/configs/double_heads/dh_faster_rcnn_r50_fpn_1x.py
new file mode 100644
index 0000000000000000000000000000000000000000..cef46570279d5835ddad67b2569ecc5ec8ccfad4
--- /dev/null
+++ b/configs/double_heads/dh_faster_rcnn_r50_fpn_1x.py
@@ -0,0 +1,170 @@
+# model settings
+model = dict(
+    type='DoubleHeadRCNN',
+    pretrained='modelzoo://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    reg_roi_scale_factor=1.3,  # extract reg features from 1.3x enlarged RoIs
+    bbox_head=dict(
+        type='DoubleConvFCBBoxHead',
+        num_convs=4,
+        num_fcs=2,
+        in_channels=256,
+        conv_out_channels=1024,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,  # 80 COCO classes + 1 background
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=2.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=2.0)))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
+    # soft-nms is also supported for rcnn testing
+    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
+)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/dh_faster_rcnn_r50_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
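+
+# A minimal usage sketch (not part of the original config; assumes the
+# standard mmdetection v1.x entry points Config.fromfile and build_detector):
+#
+#   from mmcv import Config
+#   from mmdet.models import build_detector
+#   cfg = Config.fromfile('configs/double_heads/dh_faster_rcnn_r50_fpn_1x.py')
+#   model = build_detector(cfg.model, train_cfg=cfg.train_cfg,
+#                          test_cfg=cfg.test_cfg)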
diff --git a/mmdet/models/bbox_heads/__init__.py b/mmdet/models/bbox_heads/__init__.py
index dac10bdc144636ccb3953db5dfccfc1a583ceec2..a668bdb0182361b5213adb877a21246e23ab4ede 100644
--- a/mmdet/models/bbox_heads/__init__.py
+++ b/mmdet/models/bbox_heads/__init__.py
@@ -1,4 +1,7 @@
 from .bbox_head import BBoxHead
 from .convfc_bbox_head import ConvFCBBoxHead, SharedFCBBoxHead
+from .double_bbox_head import DoubleConvFCBBoxHead
 
-__all__ = ['BBoxHead', 'ConvFCBBoxHead', 'SharedFCBBoxHead']
+__all__ = [
+    'BBoxHead', 'ConvFCBBoxHead', 'SharedFCBBoxHead', 'DoubleConvFCBBoxHead'
+]
diff --git a/mmdet/models/bbox_heads/double_bbox_head.py b/mmdet/models/bbox_heads/double_bbox_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..7211d9b49373f2c646e6c30c8c02f60b6c8f7a35
--- /dev/null
+++ b/mmdet/models/bbox_heads/double_bbox_head.py
@@ -0,0 +1,170 @@
+import torch.nn as nn
+from mmcv.cnn.weight_init import normal_init, xavier_init
+
+from .bbox_head import BBoxHead
+from ..backbones.resnet import Bottleneck
+from ..registry import HEADS
+from ..utils import ConvModule
+
+
+class BasicResBlock(nn.Module):
+    """Basic residual block.
+
+    This block differs from the BasicBlock in the ResNet backbone: here
+    conv1 is 3x3 and conv2 is 1x1, whereas both convs of the ResNet
+    BasicBlock are 3x3.
+
+    Args:
+        in_channels (int): Channels of the input feature map.
+        out_channels (int): Channels of the output feature map.
+        conv_cfg (dict): The config dict for convolution layers.
+        norm_cfg (dict): The config dict for normalization layers.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN')):
+        super(BasicResBlock, self).__init__()
+
+        # main path
+        self.conv1 = ConvModule(
+            in_channels,
+            in_channels,
+            kernel_size=3,
+            padding=1,
+            bias=False,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg)
+        self.conv2 = ConvModule(
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            bias=False,
+            activation=None,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg)
+
+        # identity path
+        self.conv_identity = ConvModule(
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            activation=None)
+
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        identity = x
+
+        x = self.conv1(x)
+        x = self.conv2(x)
+
+        identity = self.conv_identity(identity)
+        out = x + identity
+
+        out = self.relu(out)
+        return out
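+
+    # Shape sketch (hypothetical sizes, matching the head below): an input
+    # of (N, 256, 7, 7) passes conv1 (3x3, 256 -> 256) and conv2
+    # (1x1, 256 -> 1024) on the main path and conv_identity
+    # (1x1, 256 -> 1024) on the identity path; the sum goes through a
+    # ReLU, giving an output of (N, 1024, 7, 7).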
+
+
+@HEADS.register_module
+class DoubleConvFCBBoxHead(BBoxHead):
+    """Bbox head used in Double-Head R-CNN
+
+                                      /-> cls
+                  /-> shared convs ->
+                                      \-> reg
+    roi features
+                                      /-> cls
+                  \-> shared fc    ->
+                                      \-> reg
+    """  # noqa: W605
+
+    def __init__(self,
+                 num_convs=0,
+                 num_fcs=0,
+                 conv_out_channels=1024,
+                 fc_out_channels=1024,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 **kwargs):
+        kwargs.setdefault('with_avg_pool', True)
+        super(DoubleConvFCBBoxHead, self).__init__(**kwargs)
+        assert self.with_avg_pool
+        assert num_convs > 0
+        assert num_fcs > 0
+        self.num_convs = num_convs
+        self.num_fcs = num_fcs
+        self.conv_out_channels = conv_out_channels
+        self.fc_out_channels = fc_out_channels
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+
+        # increase the channel number of the input features
+        self.res_block = BasicResBlock(self.in_channels,
+                                       self.conv_out_channels)
+
+        # add conv heads
+        self.conv_branch = self._add_conv_branch()
+        # add fc heads
+        self.fc_branch = self._add_fc_branch()
+
+        out_dim_reg = 4 if self.reg_class_agnostic else 4 * self.num_classes
+        self.fc_reg = nn.Linear(self.conv_out_channels, out_dim_reg)
+
+        self.fc_cls = nn.Linear(self.fc_out_channels, self.num_classes)
+        self.relu = nn.ReLU(inplace=True)
+
+    def _add_conv_branch(self):
+        """Add the fc branch which consists of a sequential of conv layers"""
+        branch_convs = nn.ModuleList()
+        for i in range(self.num_convs):
+            branch_convs.append(
+                Bottleneck(
+                    inplanes=self.conv_out_channels,
+                    planes=self.conv_out_channels // 4,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg))
+        return branch_convs
+
+    def _add_fc_branch(self):
+        """Add the fc branch which consists of a sequential of fc layers"""
+        branch_fcs = nn.ModuleList()
+        for i in range(self.num_fcs):
+            fc_in_channels = (
+                self.in_channels * self.roi_feat_size *
+                self.roi_feat_size if i == 0 else self.fc_out_channels)
+            branch_fcs.append(nn.Linear(fc_in_channels, self.fc_out_channels))
+        return branch_fcs
+
+    def init_weights(self):
+        normal_init(self.fc_cls, std=0.01)
+        normal_init(self.fc_reg, std=0.001)
+
+        for m in self.fc_branch.modules():
+            if isinstance(m, nn.Linear):
+                xavier_init(m, distribution='uniform')
+
+    def forward(self, x_cls, x_reg):
+        # conv head
+        x_conv = self.res_block(x_reg)
+
+        for conv in self.conv_branch:
+            x_conv = conv(x_conv)
+
+        if self.with_avg_pool:
+            x_conv = self.avg_pool(x_conv)
+
+        x_conv = x_conv.view(x_conv.size(0), -1)
+        bbox_pred = self.fc_reg(x_conv)
+
+        # fc head
+        x_fc = x_cls.view(x_cls.size(0), -1)
+        for fc in self.fc_branch:
+            x_fc = self.relu(fc(x_fc))
+
+        cls_score = self.fc_cls(x_fc)
+
+        return cls_score, bbox_pred
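+
+
+# A rough smoke-test sketch (assumed shapes; values mirror the config
+# above). With num_classes=81 and class-specific regression, one 256x7x7
+# RoI feature fed to both branches yields a (1, 81) cls score from the fc
+# branch and a (1, 324) box prediction from the conv branch:
+#
+#   import torch
+#   head = DoubleConvFCBBoxHead(
+#       num_convs=4, num_fcs=2, in_channels=256, conv_out_channels=1024,
+#       fc_out_channels=1024, roi_feat_size=7, num_classes=81)
+#   feat = torch.rand(1, 256, 7, 7)
+#   cls_score, bbox_pred = head(feat, feat)  # (x_cls, x_reg)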
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index 9917f5d50310ee10fbf5d3be28fd3e5dde109efa..4d97b49e2e93c8dabb4a263bf54c2c5bd5fe66e5 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -6,6 +6,7 @@ from .fast_rcnn import FastRCNN
 from .faster_rcnn import FasterRCNN
 from .mask_rcnn import MaskRCNN
 from .cascade_rcnn import CascadeRCNN
+from .double_head_rcnn import DoubleHeadRCNN
 from .htc import HybridTaskCascade
 from .retinanet import RetinaNet
 from .fcos import FCOS
@@ -15,5 +16,5 @@ from .mask_scoring_rcnn import MaskScoringRCNN
 __all__ = [
     'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
     'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
-    'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN'
+    'DoubleHeadRCNN', 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN'
 ]
diff --git a/mmdet/models/detectors/double_head_rcnn.py b/mmdet/models/detectors/double_head_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..50001872713a3ea57750cda28d1805be8ad434dc
--- /dev/null
+++ b/mmdet/models/detectors/double_head_rcnn.py
@@ -0,0 +1,154 @@
+import torch
+
+from .two_stage import TwoStageDetector
+from ..registry import DETECTORS
+from mmdet.core import bbox2roi, build_assigner, build_sampler
+
+
+@DETECTORS.register_module
+class DoubleHeadRCNN(TwoStageDetector):
+    """Double-Head R-CNN: classification RoI features are pooled at the
+    original RoI scale, regression features from RoIs enlarged by
+    ``reg_roi_scale_factor``."""
+
+    def __init__(self, reg_roi_scale_factor, **kwargs):
+        super(DoubleHeadRCNN, self).__init__(**kwargs)
+        self.reg_roi_scale_factor = reg_roi_scale_factor
+
+    def forward_train(self,
+                      img,
+                      img_meta,
+                      gt_bboxes,
+                      gt_labels,
+                      gt_bboxes_ignore=None,
+                      gt_masks=None,
+                      proposals=None):
+        x = self.extract_feat(img)
+
+        losses = dict()
+
+        # RPN forward and loss
+        if self.with_rpn:
+            rpn_outs = self.rpn_head(x)
+            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
+                                          self.train_cfg.rpn)
+            rpn_losses = self.rpn_head.loss(
+                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+            losses.update(rpn_losses)
+
+            proposal_cfg = self.train_cfg.get('rpn_proposal',
+                                              self.test_cfg.rpn)
+            proposal_inputs = rpn_outs + (img_meta, proposal_cfg)
+            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
+        else:
+            proposal_list = proposals
+
+        # assign gts and sample proposals
+        if self.with_bbox or self.with_mask:
+            bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
+            bbox_sampler = build_sampler(
+                self.train_cfg.rcnn.sampler, context=self)
+            num_imgs = img.size(0)
+            if gt_bboxes_ignore is None:
+                gt_bboxes_ignore = [None for _ in range(num_imgs)]
+            sampling_results = []
+            for i in range(num_imgs):
+                assign_result = bbox_assigner.assign(proposal_list[i],
+                                                     gt_bboxes[i],
+                                                     gt_bboxes_ignore[i],
+                                                     gt_labels[i])
+                sampling_result = bbox_sampler.sample(
+                    assign_result,
+                    proposal_list[i],
+                    gt_bboxes[i],
+                    gt_labels[i],
+                    feats=[lvl_feat[i][None] for lvl_feat in x])
+                sampling_results.append(sampling_result)
+
+        # bbox head forward and loss
+        if self.with_bbox:
+            rois = bbox2roi([res.bboxes for res in sampling_results])
+            # TODO: a more flexible way to decide which feature maps to use
+            bbox_cls_feats = self.bbox_roi_extractor(
+                x[:self.bbox_roi_extractor.num_inputs], rois)
+            bbox_reg_feats = self.bbox_roi_extractor(
+                x[:self.bbox_roi_extractor.num_inputs],
+                rois,
+                roi_scale_factor=self.reg_roi_scale_factor)
+            if self.with_shared_head:
+                bbox_cls_feats = self.shared_head(bbox_cls_feats)
+                bbox_reg_feats = self.shared_head(bbox_reg_feats)
+            cls_score, bbox_pred = self.bbox_head(bbox_cls_feats,
+                                                  bbox_reg_feats)
+
+            bbox_targets = self.bbox_head.get_target(sampling_results,
+                                                     gt_bboxes, gt_labels,
+                                                     self.train_cfg.rcnn)
+            loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
+                                            *bbox_targets)
+            losses.update(loss_bbox)
+
+        # mask head forward and loss
+        if self.with_mask:
+            if not self.share_roi_extractor:
+                pos_rois = bbox2roi(
+                    [res.pos_bboxes for res in sampling_results])
+                mask_feats = self.mask_roi_extractor(
+                    x[:self.mask_roi_extractor.num_inputs], pos_rois)
+                if self.with_shared_head:
+                    mask_feats = self.shared_head(mask_feats)
+            else:
+                pos_inds = []
+                device = bbox_cls_feats.device
+                for res in sampling_results:
+                    pos_inds.append(
+                        torch.ones(
+                            res.pos_bboxes.shape[0],
+                            device=device,
+                            dtype=torch.uint8))
+                    pos_inds.append(
+                        torch.zeros(
+                            res.neg_bboxes.shape[0],
+                            device=device,
+                            dtype=torch.uint8))
+                pos_inds = torch.cat(pos_inds)
+                mask_feats = bbox_cls_feats[pos_inds]
+            mask_pred = self.mask_head(mask_feats)
+
+            mask_targets = self.mask_head.get_target(sampling_results,
+                                                     gt_masks,
+                                                     self.train_cfg.rcnn)
+            pos_labels = torch.cat(
+                [res.pos_gt_labels for res in sampling_results])
+            loss_mask = self.mask_head.loss(mask_pred, mask_targets,
+                                            pos_labels)
+            losses.update(loss_mask)
+
+        return losses
+
+    def simple_test_bboxes(self,
+                           x,
+                           img_meta,
+                           proposals,
+                           rcnn_test_cfg,
+                           rescale=False):
+        """Test only det bboxes without augmentation."""
+        rois = bbox2roi(proposals)
+        bbox_cls_feats = self.bbox_roi_extractor(
+            x[:self.bbox_roi_extractor.num_inputs], rois)
+        bbox_reg_feats = self.bbox_roi_extractor(
+            x[:self.bbox_roi_extractor.num_inputs],
+            rois,
+            roi_scale_factor=self.reg_roi_scale_factor)
+        if self.with_shared_head:
+            bbox_cls_feats = self.shared_head(bbox_cls_feats)
+            bbox_reg_feats = self.shared_head(bbox_reg_feats)
+        cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)
+        img_shape = img_meta[0]['img_shape']
+        scale_factor = img_meta[0]['scale_factor']
+        det_bboxes, det_labels = self.bbox_head.get_det_bboxes(
+            rois,
+            cls_score,
+            bbox_pred,
+            img_shape,
+            scale_factor,
+            rescale=rescale,
+            cfg=rcnn_test_cfg)
+        return det_bboxes, det_labels
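+
+
+# A hedged inference sketch (assumes the mmdet v1.x high-level API; the
+# checkpoint path and test image below are placeholders):
+#
+#   from mmdet.apis import init_detector, inference_detector
+#   cfg = 'configs/double_heads/dh_faster_rcnn_r50_fpn_1x.py'
+#   model = init_detector(cfg, 'work_dirs/dh_faster_rcnn_r50_fpn_1x/latest.pth')
+#   result = inference_detector(model, 'demo.jpg')  # per-class bbox arrays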
diff --git a/mmdet/models/roi_extractors/single_level.py b/mmdet/models/roi_extractors/single_level.py
index 7eaab120e6e1a6a2b1243752b23348f3b0cca713..503818ef368179cb0258645941725c3e52c9d4ac 100644
--- a/mmdet/models/roi_extractors/single_level.py
+++ b/mmdet/models/roi_extractors/single_level.py
@@ -72,8 +72,22 @@ class SingleRoIExtractor(nn.Module):
         target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
         return target_lvls
 
+    def roi_rescale(self, rois, scale_factor):
+        """Rescale each RoI around its center by ``scale_factor``,
+        keeping the batch index in ``rois[:, 0]`` unchanged."""
+        cx = (rois[:, 1] + rois[:, 3]) * 0.5
+        cy = (rois[:, 2] + rois[:, 4]) * 0.5
+        w = rois[:, 3] - rois[:, 1] + 1
+        h = rois[:, 4] - rois[:, 2] + 1
+        new_w = w * scale_factor
+        new_h = h * scale_factor
+        x1 = cx - new_w * 0.5 + 0.5
+        x2 = cx + new_w * 0.5 - 0.5
+        y1 = cy - new_h * 0.5 + 0.5
+        y2 = cy + new_h * 0.5 - 0.5
+        new_rois = torch.stack((rois[:, 0], x1, y1, x2, y2), dim=-1)
+        return new_rois
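+
+    # Worked example (hypothetical numbers): a roi [0, 0, 0, 99, 99]
+    # (batch index 0, a 100x100 box) with scale_factor=1.3 keeps its center
+    # at (49.5, 49.5) and grows to 130x130, giving [0, -15, -15, 114, 114].
+    # Coordinates may fall outside the image; RoIAlign samples zeros beyond
+    # the feature map, so no clipping is applied here.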
+
     @force_fp32(apply_to=('feats', ), out_fp16=True)
-    def forward(self, feats, rois):
+    def forward(self, feats, rois, roi_scale_factor=None):
         if len(feats) == 1:
             return self.roi_layers[0](feats[0], rois)
 
@@ -82,6 +96,8 @@ class SingleRoIExtractor(nn.Module):
         target_lvls = self.map_roi_levels(rois, num_levels)
         roi_feats = feats[0].new_zeros(rois.size()[0], self.out_channels,
                                        out_size, out_size)
+        if roi_scale_factor is not None:
+            rois = self.roi_rescale(rois, roi_scale_factor)
         for i in range(num_levels):
             inds = target_lvls == i
             if inds.any():