diff --git a/mmdet/models/necks/fpn.py b/mmdet/models/necks/fpn.py
index f4557f1d221ce8da4cd96961f81f40031111c340..90c0f235eae4d495be2b248654935d42184ce72a 100644
--- a/mmdet/models/necks/fpn.py
+++ b/mmdet/models/necks/fpn.py
@@ -16,6 +16,7 @@ class FPN(nn.Module):
                  start_level=0,
                  end_level=-1,
                  add_extra_convs=False,
+                 extra_convs_on_inputs=True,
                  normalize=None,
                  activation=None):
         super(FPN, self).__init__()
@@ -38,6 +39,7 @@ class FPN(nn.Module):
         self.start_level = start_level
         self.end_level = end_level
         self.add_extra_convs = add_extra_convs
+        self.extra_convs_on_inputs = extra_convs_on_inputs
 
         self.lateral_convs = nn.ModuleList()
         self.fpn_convs = nn.ModuleList()
@@ -64,16 +66,14 @@ class FPN(nn.Module):
             self.lateral_convs.append(l_conv)
             self.fpn_convs.append(fpn_conv)
 
-            # lvl_id = i - self.start_level
-            # setattr(self, 'lateral_conv{}'.format(lvl_id), l_conv)
-            # setattr(self, 'fpn_conv{}'.format(lvl_id), fpn_conv)
-
         # add extra conv layers (e.g., RetinaNet)
         extra_levels = num_outs - self.backbone_end_level + self.start_level
         if add_extra_convs and extra_levels >= 1:
             for i in range(extra_levels):
-                in_channels = (self.in_channels[self.backbone_end_level - 1]
-                               if i == 0 else out_channels)
+                if i == 0 and self.extra_convs_on_inputs:
+                    in_channels = self.in_channels[self.backbone_end_level - 1]
+                else:
+                    in_channels = out_channels
                 extra_fpn_conv = ConvModule(
                     in_channels,
                     out_channels,
@@ -121,8 +121,11 @@ class FPN(nn.Module):
                     outs.append(F.max_pool2d(outs[-1], 1, stride=2))
             # add conv layers on top of original feature maps (RetinaNet)
             else:
-                orig = inputs[self.backbone_end_level - 1]
-                outs.append(self.fpn_convs[used_backbone_levels](orig))
+                if self.extra_convs_on_inputs:
+                    orig = inputs[self.backbone_end_level - 1]
+                    outs.append(self.fpn_convs[used_backbone_levels](orig))
+                else:
+                    outs.append(self.fpn_convs[used_backbone_levels](outs[-1]))
                 for i in range(used_backbone_levels + 1, self.num_outs):
                     # BUG: we should add relu before each extra conv
                     outs.append(self.fpn_convs[i](outs[-1]))