Skip to content
Snippets Groups Projects
Commit ff6c19f0 authored by Kai Chen's avatar Kai Chen
Browse files

add an argument to support building P6 from either C5 or P5

parent c17ed460
No related branches found
No related tags found
No related merge requests found
...@@ -16,6 +16,7 @@ class FPN(nn.Module): ...@@ -16,6 +16,7 @@ class FPN(nn.Module):
start_level=0, start_level=0,
end_level=-1, end_level=-1,
add_extra_convs=False, add_extra_convs=False,
extra_convs_on_inputs=True,
normalize=None, normalize=None,
activation=None): activation=None):
super(FPN, self).__init__() super(FPN, self).__init__()
...@@ -38,6 +39,7 @@ class FPN(nn.Module): ...@@ -38,6 +39,7 @@ class FPN(nn.Module):
self.start_level = start_level self.start_level = start_level
self.end_level = end_level self.end_level = end_level
self.add_extra_convs = add_extra_convs self.add_extra_convs = add_extra_convs
self.extra_convs_on_inputs = extra_convs_on_inputs
self.lateral_convs = nn.ModuleList() self.lateral_convs = nn.ModuleList()
self.fpn_convs = nn.ModuleList() self.fpn_convs = nn.ModuleList()
...@@ -64,16 +66,14 @@ class FPN(nn.Module): ...@@ -64,16 +66,14 @@ class FPN(nn.Module):
self.lateral_convs.append(l_conv) self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv) self.fpn_convs.append(fpn_conv)
# lvl_id = i - self.start_level
# setattr(self, 'lateral_conv{}'.format(lvl_id), l_conv)
# setattr(self, 'fpn_conv{}'.format(lvl_id), fpn_conv)
# add extra conv layers (e.g., RetinaNet) # add extra conv layers (e.g., RetinaNet)
extra_levels = num_outs - self.backbone_end_level + self.start_level extra_levels = num_outs - self.backbone_end_level + self.start_level
if add_extra_convs and extra_levels >= 1: if add_extra_convs and extra_levels >= 1:
for i in range(extra_levels): for i in range(extra_levels):
in_channels = (self.in_channels[self.backbone_end_level - 1] if i == 0 and self.extra_convs_on_inputs:
if i == 0 else out_channels) in_channels = self.in_channels[self.backbone_end_level - 1]
else:
in_channels = out_channels
extra_fpn_conv = ConvModule( extra_fpn_conv = ConvModule(
in_channels, in_channels,
out_channels, out_channels,
...@@ -121,8 +121,11 @@ class FPN(nn.Module): ...@@ -121,8 +121,11 @@ class FPN(nn.Module):
outs.append(F.max_pool2d(outs[-1], 1, stride=2)) outs.append(F.max_pool2d(outs[-1], 1, stride=2))
# add conv layers on top of original feature maps (RetinaNet) # add conv layers on top of original feature maps (RetinaNet)
else: else:
orig = inputs[self.backbone_end_level - 1] if self.extra_convs_on_inputs:
outs.append(self.fpn_convs[used_backbone_levels](orig)) orig = inputs[self.backbone_end_level - 1]
outs.append(self.fpn_convs[used_backbone_levels](orig))
else:
outs.append(self.fpn_convs[used_backbone_levels](outs[-1]))
for i in range(used_backbone_levels + 1, self.num_outs): for i in range(used_backbone_levels + 1, self.num_outs):
# BUG: we should add relu before each extra conv # BUG: we should add relu before each extra conv
outs.append(self.fpn_convs[i](outs[-1])) outs.append(self.fpn_convs[i](outs[-1]))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment