diff --git a/README.md b/README.md
index 9618cd9bb14cf718b3293cb701ceacb8cee33c6d..3d579f50a208fc155101458c1667e4edae63b6cf 100644
--- a/README.md
+++ b/README.md
@@ -69,6 +69,7 @@ Results and models are available in the [Model zoo](docs/MODEL_ZOO.md).
 | RepPoints | ✓ | ✓ | ☐ | ✗ | ✓ |
 | Foveabox | ✓ | ✓ | ☐ | ✗ | ✓ |
 | FreeAnchor | ✓ | ✓ | ☐ | ✗ | ✓ |
+| NAS-FPN | ✓ | ✓ | ☐ | ✗ | ✓ |
 
 Other features
 - [x] DCNv2
diff --git a/configs/nas_fpn/README.md b/configs/nas_fpn/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f6a999048fd68da3ace17204725e5bbb83565cce
--- /dev/null
+++ b/configs/nas_fpn/README.md
@@ -0,0 +1,58 @@
+# NAS-FPN: Learning Scalable Feature Pyramid Architecture for Object Detection
+
+## Introduction
+
+```
+@inproceedings{ghiasi2019fpn,
+  title={NAS-FPN: Learning scalable feature pyramid architecture for object detection},
+  author={Ghiasi, Golnaz and Lin, Tsung-Yi and Le, Quoc V},
+  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+  pages={7036--7045},
+  year={2019}
+}
+```
+
+## Results and Models
+
+We benchmark the new training schedule introduced in NAS-FPN (crop training, large batch, unfrozen BN, 50 epochs). As in the paper, RetinaNet is used as the base detector.
+
+| Backbone | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
+|:-----------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:--------:|
+| R-50-FPN | 50e | 12.8 | 0.513 | 15.3 | 37.0 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/nas_fpn/retinanet_crop640_r50_fpn_50e_190824-4d75bfa0.pth) |
+| R-50-NASFPN | 50e | 14.8 | 0.662 | 13.1 | 39.8 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/nas_fpn/retinanet_crop640_r50_nasfpn_50e_20191225-b82d3a86.pth) |
+
+
+**Note**: We find that training NAS-FPN is unstable; there is a small chance that results come out 3% mAP lower.
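+
+## Usage
+
+A quick sketch of how these configs are meant to be launched, assuming the standard mmdetection distributed training entry point (the configs assume 8 GPUs x 8 images per GPU, i.e. a total batch size of 64, which matches `lr=0.08`):
+
+```shell
+./tools/dist_train.sh configs/nas_fpn/retinanet_crop640_r50_nasfpn_50e.py 8
+```
+
+The snippet below is a minimal, hypothetical shape check of the `NASFPN` neck on dummy ResNet-50 features for a 640x640 input; the tensor sizes are illustrative assumptions, not values taken from these configs:
+
+```python
+import torch
+
+from mmdet.models.necks import NASFPN
+
+neck = NASFPN(
+    in_channels=[256, 512, 1024, 2048],
+    out_channels=256,
+    num_outs=5,
+    stack_times=7,
+    start_level=1,
+    norm_cfg=dict(type='BN', requires_grad=True))
+neck.init_weights()
+
+# C2-C5 feature maps of a 640x640 image (strides 4, 8, 16, 32)
+feats = [
+    torch.rand(1, c, 640 // s, 640 // s)
+    for c, s in zip([256, 512, 1024, 2048], [4, 8, 16, 32])
+]
+outs = neck(feats)
+assert len(outs) == 5  # P3-P7, strides 8 to 128
+```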
diff --git a/configs/nas_fpn/retinanet_crop640_r50_fpn_50e.py b/configs/nas_fpn/retinanet_crop640_r50_fpn_50e.py
new file mode 100644
index 0000000000000000000000000000000000000000..1587d876632710903ace05bd18e74b9f43f6b638
--- /dev/null
+++ b/configs/nas_fpn/retinanet_crop640_r50_fpn_50e.py
@@ -0,0 +1,151 @@
+cudnn_benchmark = True
+# model settings
+norm_cfg = dict(type='BN', requires_grad=True)
+model = dict(
+    type='RetinaNet',
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs=True,
+        relu_before_extra_convs=True,
+        no_norm_on_lateral=True,
+        norm_cfg=norm_cfg,
+        num_outs=5),
+    bbox_head=dict(
+        type='RetinaSepBNHead',
+        num_classes=81,
+        num_ins=5,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        octave_base_scale=4,
+        scales_per_octave=3,
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[8, 16, 32, 64, 128],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        norm_cfg=norm_cfg,
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=1.0)))
+# training and testing settings
+train_cfg = dict(
+    assigner=dict(
+        type='MaxIoUAssigner',
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+        min_pos_iou=0,
+        ignore_iof_thr=-1),
+    allowed_border=-1,
+    pos_weight=-1,
+    debug=False)
+test_cfg = dict(
+    nms_pre=1000,
+    min_bbox_size=0,
+    score_thr=0.05,
+    nms=dict(type='nms', iou_thr=0.5),
+    max_per_img=100)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(
+        type='Resize',
+        img_scale=(640, 640),
+        ratio_range=(0.8, 1.2),
+        keep_ratio=True),
+    dict(type='RandomCrop', crop_size=(640, 640)),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=(640, 640)),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(640, 640),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=64),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    imgs_per_gpu=8,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline))
+# optimizer
+optimizer = dict(
+    type='SGD',
+    lr=0.08,
+    momentum=0.9,
+    weight_decay=0.0001,
+    paramwise_options=dict(norm_decay_mult=0))
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
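+    # linear warmup from 10% of the base lr over the first 1000 iters,
+    # then 10x decay at epochs 30 and 40 of the 50-epoch schedule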
+    warmup='linear',
+    warmup_iters=1000,
+    warmup_ratio=0.1,
+    step=[30, 40])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 50
+device_ids = range(8)
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/retinanet_crop640_r50_fpn_50e'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/nas_fpn/retinanet_crop640_r50_nasfpn_50e.py b/configs/nas_fpn/retinanet_crop640_r50_nasfpn_50e.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5a37042bbf0fae1a00899b44e8496a86795f40a
--- /dev/null
+++ b/configs/nas_fpn/retinanet_crop640_r50_nasfpn_50e.py
@@ -0,0 +1,149 @@
+cudnn_benchmark = True
+# model settings
+norm_cfg = dict(type='BN', requires_grad=True)
+model = dict(
+    type='RetinaNet',
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch'),
+    neck=dict(
+        type='NASFPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5,
+        stack_times=7,
+        start_level=1,
+        add_extra_convs=True,
+        norm_cfg=norm_cfg),
+    bbox_head=dict(
+        type='RetinaSepBNHead',
+        num_classes=81,
+        num_ins=5,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        octave_base_scale=4,
+        scales_per_octave=3,
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[8, 16, 32, 64, 128],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        norm_cfg=norm_cfg,
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=1.0)))
+# training and testing settings
+train_cfg = dict(
+    assigner=dict(
+        type='MaxIoUAssigner',
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+        min_pos_iou=0,
+        ignore_iof_thr=-1),
+    allowed_border=-1,
+    pos_weight=-1,
+    debug=False)
+test_cfg = dict(
+    nms_pre=1000,
+    min_bbox_size=0,
+    score_thr=0.05,
+    nms=dict(type='nms', iou_thr=0.5),
+    max_per_img=100)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(
+        type='Resize',
+        img_scale=(640, 640),
+        ratio_range=(0.8, 1.2),
+        keep_ratio=True),
+    dict(type='RandomCrop', crop_size=(640, 640)),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=(640, 640)),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(640, 640),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=128),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    imgs_per_gpu=8,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
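+        # NOTE: val2017 also serves as the test split in this config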
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline))
+# optimizer
+optimizer = dict(
+    type='SGD',
+    lr=0.08,
+    momentum=0.9,
+    weight_decay=0.0001,
+    paramwise_options=dict(norm_decay_mult=0))
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=1000,
+    warmup_ratio=0.1,
+    step=[30, 40])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 50
+device_ids = range(8)
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/retinanet_crop640_r50_nasfpn_50e'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/docs/MODEL_ZOO.md b/docs/MODEL_ZOO.md
index 53b65ad5f06dfe69a910eb11e91e53a4a0c0ff3f..35314574ffd3ffaf3e500f68cc9fb8871763f0aa 100644
--- a/docs/MODEL_ZOO.md
+++ b/docs/MODEL_ZOO.md
@@ -277,6 +277,9 @@ Please refer to [Mask Scoring R-CNN](https://github.com/open-mmlab/mmdetection/b
 
 Please refer to [Rethinking ImageNet Pre-training](https://github.com/open-mmlab/mmdetection/blob/master/configs/scratch) for details.
 
+### NAS-FPN
+Please refer to [NAS-FPN](https://github.com/open-mmlab/mmdetection/blob/master/configs/nas_fpn) for details.
+
 ### Other datasets
 
 We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face).
diff --git a/mmdet/models/anchor_heads/__init__.py b/mmdet/models/anchor_heads/__init__.py
index 54db861fd9c122b8343aef12aa15f28d4ca7ee9c..c693c0fd209cf34eaf01a6ce5b28e40d883b6ea6 100644
--- a/mmdet/models/anchor_heads/__init__.py
+++ b/mmdet/models/anchor_heads/__init__.py
@@ -7,11 +7,12 @@ from .ga_rpn_head import GARPNHead
 from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
 from .reppoints_head import RepPointsHead
 from .retina_head import RetinaHead
+from .retina_sepbn_head import RetinaSepBNHead
 from .rpn_head import RPNHead
 from .ssd_head import SSDHead
 
 __all__ = [
     'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
-    'GARPNHead', 'RetinaHead', 'GARetinaHead', 'SSDHead', 'FCOSHead',
-    'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead'
+    'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead',
+    'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead'
 ]
diff --git a/mmdet/models/anchor_heads/retina_sepbn_head.py b/mmdet/models/anchor_heads/retina_sepbn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f076617908ee956e595abc26ba7b1921390ea74
--- /dev/null
+++ b/mmdet/models/anchor_heads/retina_sepbn_head.py
@@ -0,0 +1,109 @@
+import numpy as np
+import torch.nn as nn
+from mmcv.cnn import normal_init
+
+from ..registry import HEADS
+from ..utils import ConvModule, bias_init_with_prob
+from .anchor_head import AnchorHead
+
+
+@HEADS.register_module
+class RetinaSepBNHead(AnchorHead):
+    """RetinaHead with separate BN.
+
+    In RetinaHead, conv/norm layers are shared across different FPN levels,
+    while in RetinaSepBNHead, conv layers are shared across different FPN
+    levels, but BN layers are separated.
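+
+    Conv weights are shared across levels by assigning the level-0 convs
+    to all other levels in ``_init_layers``, so only the BN layers keep
+    per-level statistics and affine parameters.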
+ """ + + def __init__(self, + num_classes, + num_ins, + in_channels, + stacked_convs=4, + octave_base_scale=4, + scales_per_octave=3, + conv_cfg=None, + norm_cfg=None, + **kwargs): + self.stacked_convs = stacked_convs + self.octave_base_scale = octave_base_scale + self.scales_per_octave = scales_per_octave + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.num_ins = num_ins + octave_scales = np.array( + [2**(i / scales_per_octave) for i in range(scales_per_octave)]) + anchor_scales = octave_scales * octave_base_scale + super(RetinaSepBNHead, self).__init__( + num_classes, in_channels, anchor_scales=anchor_scales, **kwargs) + + def _init_layers(self): + self.relu = nn.ReLU(inplace=True) + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + for i in range(self.num_ins): + cls_convs = nn.ModuleList() + reg_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + reg_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + self.cls_convs.append(cls_convs) + self.reg_convs.append(reg_convs) + for i in range(self.stacked_convs): + for j in range(1, self.num_ins): + self.cls_convs[j][i].conv = self.cls_convs[0][i].conv + self.reg_convs[j][i].conv = self.reg_convs[0][i].conv + self.retina_cls = nn.Conv2d( + self.feat_channels, + self.num_anchors * self.cls_out_channels, + 3, + padding=1) + self.retina_reg = nn.Conv2d( + self.feat_channels, self.num_anchors * 4, 3, padding=1) + + def init_weights(self): + for m in self.cls_convs[0]: + normal_init(m.conv, std=0.01) + for m in self.reg_convs[0]: + normal_init(m.conv, std=0.01) + bias_cls = bias_init_with_prob(0.01) + normal_init(self.retina_cls, std=0.01, bias=bias_cls) + normal_init(self.retina_reg, std=0.01) + + def forward(self, feats): + cls_scores = [] + bbox_preds = [] + for i, x in enumerate(feats): + cls_feat = feats[i] + reg_feat = feats[i] + for cls_conv in self.cls_convs[i]: + cls_feat = cls_conv(cls_feat) + for reg_conv in self.reg_convs[i]: + reg_feat = reg_conv(reg_feat) + cls_score = self.retina_cls(cls_feat) + bbox_pred = self.retina_reg(reg_feat) + cls_scores.append(cls_score) + bbox_preds.append(bbox_pred) + return cls_scores, bbox_preds diff --git a/mmdet/models/necks/__init__.py b/mmdet/models/necks/__init__.py index 6b26e5fee369f5d6578e8f4df9c4ff323d3510c6..fa5740443fa7d886878014ff55766cc2f60ac944 100644 --- a/mmdet/models/necks/__init__.py +++ b/mmdet/models/necks/__init__.py @@ -1,5 +1,6 @@ from .bfp import BFP from .fpn import FPN from .hrfpn import HRFPN +from .nas_fpn import NASFPN -__all__ = ['FPN', 'BFP', 'HRFPN'] +__all__ = ['FPN', 'BFP', 'HRFPN', 'NASFPN'] diff --git a/mmdet/models/necks/nas_fpn.py b/mmdet/models/necks/nas_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..b0a6898374b40bdca97d1f57918a169465e4a00b --- /dev/null +++ b/mmdet/models/necks/nas_fpn.py @@ -0,0 +1,186 @@ +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import caffe2_xavier_init + +from ..registry import NECKS +from ..utils import ConvModule + + +class MergingCell(nn.Module): + + def __init__(self, channels=256, with_conv=True, norm_cfg=None): + super(MergingCell, self).__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv_out = ConvModule( + channels, + 
+                channels,
+                3,
+                padding=1,
+                norm_cfg=norm_cfg,
+                order=('act', 'conv', 'norm'))
+
+    def _binary_op(self, x1, x2):
+        raise NotImplementedError
+
+    def _resize(self, x, size):
+        if x.shape[-2:] == size:
+            return x
+        elif x.shape[-2:] < size:
+            return F.interpolate(x, size=size, mode='nearest')
+        else:
+            assert x.shape[-2] % size[-2] == 0 and x.shape[-1] % size[-1] == 0
+            kernel_size = x.shape[-1] // size[-1]
+            x = F.max_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
+            return x
+
+    def forward(self, x1, x2, out_size):
+        assert x1.shape[:2] == x2.shape[:2]
+        assert len(out_size) == 2
+
+        x1 = self._resize(x1, out_size)
+        x2 = self._resize(x2, out_size)
+
+        x = self._binary_op(x1, x2)
+        if self.with_conv:
+            x = self.conv_out(x)
+        return x
+
+
+class SumCell(MergingCell):
+
+    def _binary_op(self, x1, x2):
+        return x1 + x2
+
+
+class GPCell(MergingCell):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
+
+    def _binary_op(self, x1, x2):
+        x2_att = self.global_pool(x2).sigmoid()
+        return x2 + x2_att * x1
+
+
+@NECKS.register_module
+class NASFPN(nn.Module):
+    """NAS-FPN.
+
+    NAS-FPN: Learning Scalable Feature Pyramid Architecture for Object
+    Detection. (https://arxiv.org/abs/1904.07392)
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 num_outs,
+                 stack_times,
+                 start_level=0,
+                 end_level=-1,
+                 add_extra_convs=False,
+                 norm_cfg=None):
+        super(NASFPN, self).__init__()
+        assert isinstance(in_channels, list)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_ins = len(in_channels)  # num of input feature levels
+        self.num_outs = num_outs  # num of output feature levels
+        self.stack_times = stack_times
+        self.norm_cfg = norm_cfg
+
+        if end_level == -1:
+            self.backbone_end_level = self.num_ins
+            assert num_outs >= self.num_ins - start_level
+        else:
+            # if end_level < inputs, no extra level is allowed
+            self.backbone_end_level = end_level
+            assert end_level <= len(in_channels)
+            assert num_outs == end_level - start_level
+        self.start_level = start_level
+        self.end_level = end_level
+        self.add_extra_convs = add_extra_convs
+
+        # add lateral connections
+        self.lateral_convs = nn.ModuleList()
+        for i in range(self.start_level, self.backbone_end_level):
+            l_conv = ConvModule(
+                in_channels[i],
+                out_channels,
+                1,
+                norm_cfg=norm_cfg,
+                activation=None)
+            self.lateral_convs.append(l_conv)
+
+        # add extra downsample layers (1x1 conv followed by stride-2 max pooling)
+        extra_levels = num_outs - self.backbone_end_level + self.start_level
+        self.extra_downsamples = nn.ModuleList()
+        for _ in range(extra_levels):
+            extra_conv = ConvModule(
+                out_channels,
+                out_channels,
+                1,
+                norm_cfg=norm_cfg,
+                activation=None)
+            self.extra_downsamples.append(
+                nn.Sequential(extra_conv, nn.MaxPool2d(2, 2)))
+
+        # add NAS FPN connections
+        self.fpn_stages = nn.ModuleList()
+        for _ in range(self.stack_times):
+            stage = nn.ModuleDict()
+            # gp(p6, p4) -> p4_1
+            stage['gp_64_4'] = GPCell(out_channels, norm_cfg=norm_cfg)
+            # sum(p4_1, p4) -> p4_2
+            stage['sum_44_4'] = SumCell(out_channels, norm_cfg=norm_cfg)
+            # sum(p4_2, p3) -> p3_out
+            stage['sum_43_3'] = SumCell(out_channels, norm_cfg=norm_cfg)
+            # sum(p3_out, p4_2) -> p4_out
+            stage['sum_34_4'] = SumCell(out_channels, norm_cfg=norm_cfg)
+            # sum(p5, gp(p4_out, p3_out)) -> p5_out
+            stage['gp_43_5'] = GPCell(with_conv=False)
+            stage['sum_55_5'] = SumCell(out_channels, norm_cfg=norm_cfg)
+            # sum(p7, gp(p5_out, p4_2)) -> p7_out
+            stage['gp_54_7'] = GPCell(with_conv=False)
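+            # with_conv=False for the inner GP cells: the SumCell that
+            # consumes their output already applies the output conv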
+            stage['sum_77_7'] = SumCell(out_channels, norm_cfg=norm_cfg)
+            # gp(p7_out, p5_out) -> p6_out
+            stage['gp_75_6'] = GPCell(out_channels, norm_cfg=norm_cfg)
+            self.fpn_stages.append(stage)
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                caffe2_xavier_init(m)
+
+    def forward(self, inputs):
+        # build P3-P5
+        feats = [
+            lateral_conv(inputs[i + self.start_level])
+            for i, lateral_conv in enumerate(self.lateral_convs)
+        ]
+        # build P6-P7 on top of P5
+        for downsample in self.extra_downsamples:
+            feats.append(downsample(feats[-1]))
+
+        p3, p4, p5, p6, p7 = feats
+
+        for stage in self.fpn_stages:
+            # gp(p6, p4) -> p4_1
+            p4_1 = stage['gp_64_4'](p6, p4, out_size=p4.shape[-2:])
+            # sum(p4_1, p4) -> p4_2
+            p4_2 = stage['sum_44_4'](p4_1, p4, out_size=p4.shape[-2:])
+            # sum(p4_2, p3) -> p3_out
+            p3 = stage['sum_43_3'](p4_2, p3, out_size=p3.shape[-2:])
+            # sum(p3_out, p4_2) -> p4_out
+            p4 = stage['sum_34_4'](p3, p4_2, out_size=p4.shape[-2:])
+            # sum(p5, gp(p4_out, p3_out)) -> p5_out
+            p5_tmp = stage['gp_43_5'](p4, p3, out_size=p5.shape[-2:])
+            p5 = stage['sum_55_5'](p5, p5_tmp, out_size=p5.shape[-2:])
+            # sum(p7, gp(p5_out, p4_2)) -> p7_out
+            p7_tmp = stage['gp_54_7'](p5, p4_2, out_size=p7.shape[-2:])
+            p7 = stage['sum_77_7'](p7, p7_tmp, out_size=p7.shape[-2:])
+            # gp(p7_out, p5_out) -> p6_out
+            p6 = stage['gp_75_6'](p7, p5, out_size=p6.shape[-2:])
+
+        return p3, p4, p5, p6, p7