diff --git a/mmdet/models/necks/fpn.py b/mmdet/models/necks/fpn.py index f4557f1d221ce8da4cd96961f81f40031111c340..90c0f235eae4d495be2b248654935d42184ce72a 100644 --- a/mmdet/models/necks/fpn.py +++ b/mmdet/models/necks/fpn.py @@ -16,6 +16,7 @@ class FPN(nn.Module): start_level=0, end_level=-1, add_extra_convs=False, + extra_convs_on_inputs=True, normalize=None, activation=None): super(FPN, self).__init__() @@ -38,6 +39,7 @@ class FPN(nn.Module): self.start_level = start_level self.end_level = end_level self.add_extra_convs = add_extra_convs + self.extra_convs_on_inputs = extra_convs_on_inputs self.lateral_convs = nn.ModuleList() self.fpn_convs = nn.ModuleList() @@ -64,16 +66,14 @@ class FPN(nn.Module): self.lateral_convs.append(l_conv) self.fpn_convs.append(fpn_conv) - # lvl_id = i - self.start_level - # setattr(self, 'lateral_conv{}'.format(lvl_id), l_conv) - # setattr(self, 'fpn_conv{}'.format(lvl_id), fpn_conv) - # add extra conv layers (e.g., RetinaNet) extra_levels = num_outs - self.backbone_end_level + self.start_level if add_extra_convs and extra_levels >= 1: for i in range(extra_levels): - in_channels = (self.in_channels[self.backbone_end_level - 1] - if i == 0 else out_channels) + if i == 0 and self.extra_convs_on_inputs: + in_channels = self.in_channels[self.backbone_end_level - 1] + else: + in_channels = out_channels extra_fpn_conv = ConvModule( in_channels, out_channels, @@ -121,8 +121,11 @@ class FPN(nn.Module): outs.append(F.max_pool2d(outs[-1], 1, stride=2)) # add conv layers on top of original feature maps (RetinaNet) else: - orig = inputs[self.backbone_end_level - 1] - outs.append(self.fpn_convs[used_backbone_levels](orig)) + if self.extra_convs_on_inputs: + orig = inputs[self.backbone_end_level - 1] + outs.append(self.fpn_convs[used_backbone_levels](orig)) + else: + outs.append(self.fpn_convs[used_backbone_levels](outs[-1])) for i in range(used_backbone_levels + 1, self.num_outs): # BUG: we should add relu before each extra conv outs.append(self.fpn_convs[i](outs[-1]))