diff --git a/configs/faster_rcnn_r50_fpn_1x.py b/configs/faster_rcnn_r50_fpn_1x.py
index f4803f0b045e3801d2a09b652d6869625fb589f0..b15405e0997abfd9aef23b3317afc55da6c141dd 100644
--- a/configs/faster_rcnn_r50_fpn_1x.py
+++ b/configs/faster_rcnn_r50_fpn_1x.py
@@ -3,7 +3,7 @@ model = dict(
     type='FasterRCNN',
     pretrained='modelzoo://resnet50',
     backbone=dict(
-        type='resnet',
+        type='ResNet',
         depth=50,
         num_stages=4,
         out_indices=(0, 1, 2, 3),
diff --git a/configs/mask_rcnn_r50_fpn_1x.py b/configs/mask_rcnn_r50_fpn_1x.py
index 4760821e24464b2e21d5ac0b0b0418f4163e9494..e2d47217cc4bfac3ee8f52450a74a042fb2d8189 100644
--- a/configs/mask_rcnn_r50_fpn_1x.py
+++ b/configs/mask_rcnn_r50_fpn_1x.py
@@ -3,7 +3,7 @@ model = dict(
     type='MaskRCNN',
     pretrained='modelzoo://resnet50',
     backbone=dict(
-        type='resnet',
+        type='ResNet',
         depth=50,
         num_stages=4,
         out_indices=(0, 1, 2, 3),
diff --git a/configs/rpn_r50_fpn_1x.py b/configs/rpn_r50_fpn_1x.py
index 4e45eb9e41b8b727256b2abfe974e12802b73560..7f1b6d0ca39558694292610ba93979099eb0ada8 100644
--- a/configs/rpn_r50_fpn_1x.py
+++ b/configs/rpn_r50_fpn_1x.py
@@ -3,7 +3,7 @@ model = dict(
     type='RPN',
     pretrained='modelzoo://resnet50',
     backbone=dict(
-        type='resnet',
+        type='ResNet',
         depth=50,
         num_stages=4,
         out_indices=(0, 1, 2, 3),
diff --git a/mmdet/models/backbones/__init__.py b/mmdet/models/backbones/__init__.py
index 107507ceaf6d1a36cafe07197cefd9693a13a49b..0f82f92aad10ed86b6528f0554615d7e9589ce1c 100644
--- a/mmdet/models/backbones/__init__.py
+++ b/mmdet/models/backbones/__init__.py
@@ -1,3 +1,3 @@
-from .resnet import resnet
+from .resnet import ResNet
 
-__all__ = ['resnet']
+__all__ = ['ResNet']
diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py
index 371f4f59feca466eca0040faeb1ae7de5e78800f..66684b154b5aea3364789495b43c8b31ab97745b 100644
--- a/mmdet/models/backbones/resnet.py
+++ b/mmdet/models/backbones/resnet.py
@@ -1,8 +1,9 @@
 import logging
-import math
 
 import torch.nn as nn
 import torch.utils.checkpoint as cp
+
+from mmcv.cnn import constant_init, kaiming_init
 from mmcv.runner import load_checkpoint
 
 
@@ -27,7 +28,8 @@ class BasicBlock(nn.Module):
                  stride=1,
                  dilation=1,
                  downsample=None,
-                 style='pytorch'):
+                 style='pytorch',
+                 with_cp=False):
         super(BasicBlock, self).__init__()
         self.conv1 = conv3x3(inplanes, planes, stride, dilation)
         self.bn1 = nn.BatchNorm2d(planes)
@@ -37,6 +39,7 @@ class BasicBlock(nn.Module):
         self.downsample = downsample
         self.stride = stride
         self.dilation = dilation
+        assert not with_cp
 
     def forward(self, x):
         residual = x
@@ -69,7 +72,6 @@ class Bottleneck(nn.Module):
                  style='pytorch',
                  with_cp=False):
         """Bottleneck block.
-
         If style is "pytorch", the stride-two layer is the 3x3 conv layer,
         if it is "caffe", the stride-two layer is the first 1x1 conv layer.
         """
@@ -174,64 +176,73 @@ def make_res_layer(block,
     return nn.Sequential(*layers)
 
 
-class ResHead(nn.Module):
-
-    def __init__(self,
-                 block,
-                 num_blocks,
-                 stride=2,
-                 dilation=1,
-                 style='pytorch'):
-        self.layer4 = make_res_layer(
-            block,
-            1024,
-            512,
-            num_blocks,
-            stride=stride,
-            dilation=dilation,
-            style=style)
-
-    def forward(self, x):
-        return self.layer4(x)
+class ResNet(nn.Module):
+    """ResNet backbone.
 
+    Args:
+        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+        num_stages (int): Resnet stages, normally 4.
+        strides (Sequence[int]): Strides of the first block of each stage.
+        dilations (Sequence[int]): Dilation of each stage.
+        out_indices (Sequence[int]): Output from which stages.
+        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+            layer is the 3x3 conv layer, otherwise the stride-two layer is
+            the first 1x1 conv layer.
+        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+            not freezing any parameters.
+        bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
+            running stats (mean and var).
+        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed.
+    """
 
-class ResNet(nn.Module):
+    arch_settings = {
+        18: (BasicBlock, (2, 2, 2, 2)),
+        34: (BasicBlock, (3, 4, 6, 3)),
+        50: (Bottleneck, (3, 4, 6, 3)),
+        101: (Bottleneck, (3, 4, 23, 3)),
+        152: (Bottleneck, (3, 8, 36, 3))
+    }
 
     def __init__(self,
-                 block,
-                 layers,
+                 depth,
+                 num_stages=4,
                  strides=(1, 2, 2, 2),
                  dilations=(1, 1, 1, 1),
                  out_indices=(0, 1, 2, 3),
-                 frozen_stages=-1,
                  style='pytorch',
-                 sync_bn=False,
-                 with_cp=False,
-                 strict_frozen=False):
+                 frozen_stages=-1,
+                 bn_eval=True,
+                 bn_frozen=False,
+                 with_cp=False):
         super(ResNet, self).__init__()
-        if not len(layers) == len(strides) == len(dilations):
-            raise ValueError(
-                'The number of layers, strides and dilations must be equal, '
-                'but found have {} layers, {} strides and {} dilations'.format(
-                    len(layers), len(strides), len(dilations)))
-        assert max(out_indices) < len(layers)
+        if depth not in self.arch_settings:
+            raise KeyError('invalid depth {} for resnet'.format(depth))
+        assert num_stages >= 1 and num_stages <= 4
+        block, stage_blocks = self.arch_settings[depth]
+        stage_blocks = stage_blocks[:num_stages]
+        assert len(strides) == len(dilations) == num_stages
+        assert max(out_indices) < num_stages
+
         self.out_indices = out_indices
-        self.frozen_stages = frozen_stages
         self.style = style
-        self.sync_bn = sync_bn
+        self.frozen_stages = frozen_stages
+        self.bn_eval = bn_eval
+        self.bn_frozen = bn_frozen
+        self.with_cp = with_cp
+
         self.inplanes = 64
         self.conv1 = nn.Conv2d(
             3, 64, kernel_size=7, stride=2, padding=3, bias=False)
         self.bn1 = nn.BatchNorm2d(64)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.res_layers = []
-        for i, num_blocks in enumerate(layers):
 
+        self.res_layers = []
+        for i, num_blocks in enumerate(stage_blocks):
             stride = strides[i]
             dilation = dilations[i]
-
-            layer_name = 'layer{}'.format(i + 1)
             planes = 64 * 2**i
             res_layer = make_res_layer(
                 block,
@@ -243,12 +254,11 @@ class ResNet(nn.Module):
                 style=self.style,
                 with_cp=with_cp)
             self.inplanes = planes * block.expansion
+            layer_name = 'layer{}'.format(i + 1)
             self.add_module(layer_name, res_layer)
             self.res_layers.append(layer_name)
-        self.feat_dim = block.expansion * 64 * 2**(len(layers) - 1)
-        self.with_cp = with_cp
 
-        self.strict_frozen = strict_frozen
+        self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
 
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
@@ -257,11 +267,9 @@ class ResNet(nn.Module):
         elif pretrained is None:
             for m in self.modules():
                 if isinstance(m, nn.Conv2d):
-                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                    nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
+                    kaiming_init(m)
                 elif isinstance(m, nn.BatchNorm2d):
-                    nn.init.constant_(m.weight, 1)
-                    nn.init.constant_(m.bias, 0)
+                    constant_init(m, 1)
         else:
             raise TypeError('pretrained must be a str or None')
 
@@ -283,11 +291,11 @@ class ResNet(nn.Module):
 
     def train(self, mode=True):
         super(ResNet, self).train(mode)
-        if not self.sync_bn:
+        if self.bn_eval:
             for m in self.modules():
                 if isinstance(m, nn.BatchNorm2d):
                     m.eval()
-                    if self.strict_frozen:
+                    if self.bn_frozen:
                         for params in m.parameters():
                             params.requires_grad = False
         if mode and self.frozen_stages >= 0:
@@ -303,39 +311,3 @@ class ResNet(nn.Module):
                 mod.eval()
                 for param in mod.parameters():
                     param.requires_grad = False
-
-
-resnet_cfg = {
-    18: (BasicBlock, (2, 2, 2, 2)),
-    34: (BasicBlock, (3, 4, 6, 3)),
-    50: (Bottleneck, (3, 4, 6, 3)),
-    101: (Bottleneck, (3, 4, 23, 3)),
-    152: (Bottleneck, (3, 8, 36, 3))
-}
-
-
-def resnet(depth,
-           num_stages=4,
-           strides=(1, 2, 2, 2),
-           dilations=(1, 1, 1, 1),
-           out_indices=(2, ),
-           frozen_stages=-1,
-           style='pytorch',
-           sync_bn=False,
-           with_cp=False,
-           strict_frozen=False):
-    """Constructs a ResNet model.
-
-    Args:
-        depth (int): depth of resnet, from {18, 34, 50, 101, 152}
-        num_stages (int): num of resnet stages, normally 4
-        strides (list): strides of the first block of each stage
-        dilations (list): dilation of each stage
-        out_indices (list): output from which stages
-    """
-    if depth not in resnet_cfg:
-        raise KeyError('invalid depth {} for resnet'.format(depth))
-    block, layers = resnet_cfg[depth]
-    model = ResNet(block, layers[:num_stages], strides, dilations, out_indices,
-                   frozen_stages, style, sync_bn, with_cp, strict_frozen)
-    return model