Commit d5a6b5dc authored by Kai Chen, committed by GitHub

Merge pull request #489 from hellock/fpn-param

Add an argument in FPN to support building P6 from either C5 or P5
parents c17ed460 ff6c19f0
@@ -16,6 +16,7 @@ class FPN(nn.Module):
                  start_level=0,
                  end_level=-1,
                  add_extra_convs=False,
+                 extra_convs_on_inputs=True,
                  normalize=None,
                  activation=None):
         super(FPN, self).__init__()
@@ -38,6 +39,7 @@ class FPN(nn.Module):
         self.start_level = start_level
         self.end_level = end_level
         self.add_extra_convs = add_extra_convs
+        self.extra_convs_on_inputs = extra_convs_on_inputs
         self.lateral_convs = nn.ModuleList()
         self.fpn_convs = nn.ModuleList()
@@ -64,16 +66,14 @@ class FPN(nn.Module):
             self.lateral_convs.append(l_conv)
             self.fpn_convs.append(fpn_conv)
-            # lvl_id = i - self.start_level
-            # setattr(self, 'lateral_conv{}'.format(lvl_id), l_conv)
-            # setattr(self, 'fpn_conv{}'.format(lvl_id), fpn_conv)
 
         # add extra conv layers (e.g., RetinaNet)
         extra_levels = num_outs - self.backbone_end_level + self.start_level
         if add_extra_convs and extra_levels >= 1:
             for i in range(extra_levels):
-                in_channels = (self.in_channels[self.backbone_end_level - 1]
-                               if i == 0 else out_channels)
+                if i == 0 and self.extra_convs_on_inputs:
+                    in_channels = self.in_channels[self.backbone_end_level - 1]
+                else:
+                    in_channels = out_channels
                 extra_fpn_conv = ConvModule(
                     in_channels,
                     out_channels,
@@ -121,8 +121,11 @@ class FPN(nn.Module):
                     outs.append(F.max_pool2d(outs[-1], 1, stride=2))
             # add conv layers on top of original feature maps (RetinaNet)
             else:
-                orig = inputs[self.backbone_end_level - 1]
-                outs.append(self.fpn_convs[used_backbone_levels](orig))
+                if self.extra_convs_on_inputs:
+                    orig = inputs[self.backbone_end_level - 1]
+                    outs.append(self.fpn_convs[used_backbone_levels](orig))
+                else:
+                    outs.append(self.fpn_convs[used_backbone_levels](outs[-1]))
                 for i in range(used_backbone_levels + 1, self.num_outs):
                     # BUG: we should add relu before each extra conv
                     outs.append(self.fpn_convs[i](outs[-1]))
......
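
A minimal usage sketch of the new argument (the import path and the ResNet-50-style channel sizes below are assumptions for illustration, not part of this diff): with extra_convs_on_inputs=True (the default), the extra P6 level is built by convolving C5, the raw backbone output; with extra_convs_on_inputs=False, P6 is built from P5, the already-fused FPN output.

    # Minimal sketch; import path and channel sizes (ResNet-50-style C2-C5)
    # are assumptions, not taken from this diff.
    import torch
    from mmdet.models.necks import FPN  # assumed module path

    # Default after this change: build the extra P6 level from C5.
    fpn_p6_from_c5 = FPN(
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5,
        add_extra_convs=True,
        extra_convs_on_inputs=True)

    # Alternative enabled by this change: build P6 from P5 instead.
    fpn_p6_from_p5 = FPN(
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5,
        add_extra_convs=True,
        extra_convs_on_inputs=False)

    # Dummy C2-C5 feature maps; both necks should return 5 pyramid levels.
    feats = [torch.rand(1, c, 64 // 2 ** i, 64 // 2 ** i)
             for i, c in enumerate([256, 512, 1024, 2048])]
    assert len(fpn_p6_from_c5(feats)) == 5
    assert len(fpn_p6_from_p5(feats)) == 5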