diff --git a/mmdet/ops/__init__.py b/mmdet/ops/__init__.py
index 21c2f327d019f05fd5a32575bb0c4ad71fe39ef8..9576335a3063a1e05ccc69b78fe28c5a85c16acb 100644
--- a/mmdet/ops/__init__.py
+++ b/mmdet/ops/__init__.py
@@ -1,14 +1,14 @@
-from .dcn import (DeformConv, DeformRoIPooling, DeformRoIPoolingPack,
-                  ModulatedDeformRoIPoolingPack, ModulatedDeformConv,
-                  ModulatedDeformConvPack, deform_conv, modulated_deform_conv,
-                  deform_roi_pooling)
+from .dcn import (DeformConv, DeformConvPack, ModulatedDeformConv,
+                  ModulatedDeformConvPack, DeformRoIPooling,
+                  DeformRoIPoolingPack, ModulatedDeformRoIPoolingPack,
+                  deform_conv, modulated_deform_conv, deform_roi_pooling)
 from .nms import nms, soft_nms
 from .roi_align import RoIAlign, roi_align
 from .roi_pool import RoIPool, roi_pool
 
 __all__ = [
     'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool',
-    'DeformConv', 'DeformRoIPooling', 'DeformRoIPoolingPack',
+    'DeformConv', 'DeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
     'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
     'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv',
     'deform_roi_pooling'
diff --git a/mmdet/ops/dcn/__init__.py b/mmdet/ops/dcn/__init__.py
index 1e158d01b04fbc01278021de19ed641bc6fde414..165e63725354de429a448d866f665cccca991916 100644
--- a/mmdet/ops/dcn/__init__.py
+++ b/mmdet/ops/dcn/__init__.py
@@ -1,13 +1,13 @@
 from .functions.deform_conv import deform_conv, modulated_deform_conv
 from .functions.deform_pool import deform_roi_pooling
 from .modules.deform_conv import (DeformConv, ModulatedDeformConv,
-                                  ModulatedDeformConvPack)
+                                  DeformConvPack, ModulatedDeformConvPack)
 from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack,
                                   ModulatedDeformRoIPoolingPack)
 
 __all__ = [
-    'DeformConv', 'DeformRoIPooling', 'DeformRoIPoolingPack',
-    'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
-    'ModulatedDeformConvPack', 'deform_conv',
-    'modulated_deform_conv', 'deform_roi_pooling'
+    'DeformConv', 'DeformConvPack', 'ModulatedDeformConv',
+    'ModulatedDeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
+    'ModulatedDeformRoIPoolingPack', 'deform_conv', 'modulated_deform_conv',
+    'deform_roi_pooling'
 ]
diff --git a/mmdet/ops/dcn/modules/deform_conv.py b/mmdet/ops/dcn/modules/deform_conv.py
index e1b0a63c7a0b65ecf91362c5c2469d94635cc82c..016ad16647257f9497b2bdc6fc192eb7ada839e5 100644
--- a/mmdet/ops/dcn/modules/deform_conv.py
+++ b/mmdet/ops/dcn/modules/deform_conv.py
@@ -19,15 +19,16 @@ class DeformConv(nn.Module):
                  groups=1,
                  deformable_groups=1,
                  bias=False):
-        assert not bias
         super(DeformConv, self).__init__()
 
+        assert not bias
         assert in_channels % groups == 0, \
             'in_channels {} cannot be divisible by groups {}'.format(
                 in_channels, groups)
         assert out_channels % groups == 0, \
             'out_channels {} cannot be divisible by groups {}'.format(
                 out_channels, groups)
+
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.kernel_size = _pair(kernel_size)
@@ -50,10 +51,34 @@ class DeformConv(nn.Module):
         stdv = 1. / math.sqrt(n)
         self.weight.data.uniform_(-stdv, stdv)
 
-    def forward(self, input, offset):
-        return deform_conv(input, offset, self.weight, self.stride,
-                           self.padding, self.dilation, self.groups,
-                           self.deformable_groups)
+    def forward(self, x, offset):
+        return deform_conv(x, offset, self.weight, self.stride, self.padding,
+                           self.dilation, self.groups, self.deformable_groups)
+
+
+class DeformConvPack(DeformConv):
+    """DeformConv that predicts its own offsets from the input with a conv."""
+    def __init__(self, *args, **kwargs):
+        super(DeformConvPack, self).__init__(*args, **kwargs)
+        # Offset branch: 2 * kh * kw channels per deformable group, zero-init.
+        self.conv_offset = nn.Conv2d(
+            self.in_channels,
+            self.deformable_groups * 2 * self.kernel_size[0] *
+            self.kernel_size[1],
+            kernel_size=self.kernel_size,
+            stride=_pair(self.stride),
+            padding=_pair(self.padding),
+            bias=True)
+        self.init_offset()
+
+    def init_offset(self):
+        self.conv_offset.weight.data.zero_()
+        self.conv_offset.bias.data.zero_()
+
+    def forward(self, x):
+        offset = self.conv_offset(x)
+        return deform_conv(x, offset, self.weight, self.stride, self.padding,
+                           self.dilation, self.groups, self.deformable_groups)
 
 
 class ModulatedDeformConv(nn.Module):
@@ -97,30 +122,19 @@ class ModulatedDeformConv(nn.Module):
         if self.bias is not None:
             self.bias.data.zero_()
 
-    def forward(self, input, offset, mask):
-        return modulated_deform_conv(
-            input, offset, mask, self.weight, self.bias, self.stride,
-            self.padding, self.dilation, self.groups, self.deformable_groups)
+    def forward(self, x, offset, mask):
+        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
+                                     self.stride, self.padding, self.dilation,
+                                     self.groups, self.deformable_groups)
 
 
 class ModulatedDeformConvPack(ModulatedDeformConv):
 
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 deformable_groups=1,
-                 bias=True):
-        super(ModulatedDeformConvPack, self).__init__(
-            in_channels, out_channels, kernel_size, stride, padding, dilation,
-            groups, deformable_groups, bias)
+    def __init__(self, *args, **kwargs):
+        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
 
         self.conv_offset_mask = nn.Conv2d(
-            self.in_channels // self.groups,
+            self.in_channels,
             self.deformable_groups * 3 * self.kernel_size[0] *
             self.kernel_size[1],
             kernel_size=self.kernel_size,
@@ -133,11 +147,11 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
         self.conv_offset_mask.weight.data.zero_()
         self.conv_offset_mask.bias.data.zero_()
 
-    def forward(self, input):
-        out = self.conv_offset_mask(input)
+    def forward(self, x):
+        out = self.conv_offset_mask(x)
         o1, o2, mask = torch.chunk(out, 3, dim=1)
         offset = torch.cat((o1, o2), dim=1)
         mask = torch.sigmoid(mask)
-        return modulated_deform_conv(
-            input, offset, mask, self.weight, self.bias, self.stride,
-            self.padding, self.dilation, self.groups, self.deformable_groups)
+        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
+                                     self.stride, self.padding, self.dilation,
+                                     self.groups, self.deformable_groups)
diff --git a/mmdet/ops/dcn/modules/deform_pool.py b/mmdet/ops/dcn/modules/deform_pool.py
index b7f6a2843bad14dfa4c70f7121e640fab7ffe67c..5e0196753ee1b427263bc397e0ae842af6a9938b 100644
--- a/mmdet/ops/dcn/modules/deform_pool.py
+++ b/mmdet/ops/dcn/modules/deform_pool.py
@@ -44,22 +44,28 @@ class DeformRoIPoolingPack(DeformRoIPooling):
                  part_size=None,
                  sample_per_part=4,
                  trans_std=.0,
+                 num_offset_fcs=3,
                  deform_fc_channels=1024):
         super(DeformRoIPoolingPack,
               self).__init__(spatial_scale, out_size, out_channels, no_trans,
                              group_size, part_size, sample_per_part, trans_std)
 
+        self.num_offset_fcs = num_offset_fcs
         self.deform_fc_channels = deform_fc_channels
 
         if not no_trans:
-            self.offset_fc = nn.Sequential(
-                nn.Linear(self.out_size * self.out_size * self.out_channels,
-                          self.deform_fc_channels),
-                nn.ReLU(inplace=True),
-                nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
-                nn.ReLU(inplace=True),
-                nn.Linear(self.deform_fc_channels,
-                          self.out_size * self.out_size * 2))
+            seq = []
+            ic = self.out_size * self.out_size * self.out_channels
+            for i in range(self.num_offset_fcs):
+                if i < self.num_offset_fcs - 1:
+                    oc = self.deform_fc_channels
+                else:
+                    oc = self.out_size * self.out_size * 2
+                seq.append(nn.Linear(ic, oc))
+                ic = oc
+                if i < self.num_offset_fcs - 1:
+                    seq.append(nn.ReLU(inplace=True))
+            self.offset_fc = nn.Sequential(*seq)
             self.offset_fc[-1].weight.data.zero_()
             self.offset_fc[-1].bias.data.zero_()
 
@@ -97,33 +103,49 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
                  part_size=None,
                  sample_per_part=4,
                  trans_std=.0,
+                 num_offset_fcs=3,
+                 num_mask_fcs=2,
                  deform_fc_channels=1024):
         super(ModulatedDeformRoIPoolingPack, self).__init__(
             spatial_scale, out_size, out_channels, no_trans, group_size,
             part_size, sample_per_part, trans_std)
 
+        self.num_offset_fcs = num_offset_fcs
+        self.num_mask_fcs = num_mask_fcs
         self.deform_fc_channels = deform_fc_channels
 
         if not no_trans:
-            self.offset_fc = nn.Sequential(
-                nn.Linear(self.out_size * self.out_size * self.out_channels,
-                          self.deform_fc_channels),
-                nn.ReLU(inplace=True),
-                nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
-                nn.ReLU(inplace=True),
-                nn.Linear(self.deform_fc_channels,
-                          self.out_size * self.out_size * 2))
+            offset_fc_seq = []
+            ic = self.out_size * self.out_size * self.out_channels
+            for i in range(self.num_offset_fcs):
+                if i < self.num_offset_fcs - 1:
+                    oc = self.deform_fc_channels
+                else:
+                    oc = self.out_size * self.out_size * 2
+                offset_fc_seq.append(nn.Linear(ic, oc))
+                ic = oc
+                if i < self.num_offset_fcs - 1:
+                    offset_fc_seq.append(nn.ReLU(inplace=True))
+            self.offset_fc = nn.Sequential(*offset_fc_seq)
             self.offset_fc[-1].weight.data.zero_()
             self.offset_fc[-1].bias.data.zero_()
-            self.mask_fc = nn.Sequential(
-                nn.Linear(self.out_size * self.out_size * self.out_channels,
-                          self.deform_fc_channels),
-                nn.ReLU(inplace=True),
-                nn.Linear(self.deform_fc_channels,
-                          self.out_size * self.out_size * 1),
-                nn.Sigmoid())
-            self.mask_fc[2].weight.data.zero_()
-            self.mask_fc[2].bias.data.zero_()
+
+            mask_fc_seq = []
+            ic = self.out_size * self.out_size * self.out_channels
+            for i in range(self.num_mask_fcs):
+                if i < self.num_mask_fcs - 1:
+                    oc = self.deform_fc_channels
+                else:
+                    oc = self.out_size * self.out_size
+                mask_fc_seq.append(nn.Linear(ic, oc))
+                ic = oc
+                if i < self.num_mask_fcs - 1:
+                    mask_fc_seq.append(nn.ReLU(inplace=True))
+                else:
+                    mask_fc_seq.append(nn.Sigmoid())
+            self.mask_fc = nn.Sequential(*mask_fc_seq)
+            self.mask_fc[-2].weight.data.zero_()
+            self.mask_fc[-2].bias.data.zero_()
 
     def forward(self, data, rois):
         assert data.size(1) == self.out_channels