diff --git a/README.md b/README.md
index 9618cd9bb14cf718b3293cb701ceacb8cee33c6d..3d579f50a208fc155101458c1667e4edae63b6cf 100644
--- a/README.md
+++ b/README.md
@@ -69,6 +69,7 @@ Results and models are available in the [Model zoo](docs/MODEL_ZOO.md).
 | RepPoints          | ✓        | ✓        | ☐        | ✗        | ✓     |
 | Foveabox           | ✓        | ✓        | ☐        | ✗        | ✓     |
 | FreeAnchor         | ✓        | ✓        | ☐        | ✗        | ✓     |
+| NAS-FPN            | ✓        | ✓        | ☐        | ✗        | ✓     |
 
 Other features
 - [x] DCNv2
diff --git a/configs/nas_fpn/README.md b/configs/nas_fpn/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f6a999048fd68da3ace17204725e5bbb83565cce
--- /dev/null
+++ b/configs/nas_fpn/README.md
@@ -0,0 +1,25 @@
+# NAS-FPN: Learning Scalable Feature Pyramid Architecture for Object Detection
+
+## Introduction
+
+```
+@inproceedings{ghiasi2019fpn,
+  title={Nas-fpn: Learning scalable feature pyramid architecture for object detection},
+  author={Ghiasi, Golnaz and Lin, Tsung-Yi and Le, Quoc V},
+  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+  pages={7036--7045},
+  year={2019}
+}
+```
+
+## Results and Models
+
+We benchmark the new training schedule (crop training, large batch, unfrozen BN, 50 epochs) introduced in the NAS-FPN paper. RetinaNet is used as the detector, as in the paper.
+
+| Backbone    | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
+|:-----------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:--------:|
+| R-50-FPN    | 50e     | 12.8     | 0.513               | 15.3           | 37.0   | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/nas_fpn/retinanet_crop640_r50_fpn_50e_190824-4d75bfa0.pth) |
+| R-50-NASFPN | 50e     | 14.8     | 0.662               | 13.1           | 39.8   | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/nas_fpn/retinanet_crop640_r50_nasfpn_50e_20191225-b82d3a86.pth) |
+
+**Note**: We find that training NAS-FPN is unstable; there is a small chance that results are 3% mAP lower.
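+
+## Usage
+
+As a quick sanity check, the NAS-FPN neck can be built and run standalone on
+dummy ResNet-50 features. The sketch below mirrors the neck settings from
+`retinanet_crop640_r50_nasfpn_50e.py` and only verifies output shapes; the
+dummy inputs assume the 640x640 crop size used by these configs.
+
+```python
+import torch
+
+from mmdet.models.necks import NASFPN
+
+neck = NASFPN(
+    in_channels=[256, 512, 1024, 2048],
+    out_channels=256,
+    num_outs=5,
+    stack_times=7,
+    start_level=1,
+    add_extra_convs=True,
+    norm_cfg=dict(type='BN', requires_grad=True))
+neck.init_weights()
+neck.eval()  # use BN running stats; a single dummy image is enough here
+
+# Dummy C2-C5 feature maps for a 640x640 input (strides 4, 8, 16, 32).
+feats = [
+    torch.randn(1, c, 640 // s, 640 // s)
+    for c, s in zip([256, 512, 1024, 2048], [4, 8, 16, 32])
+]
+with torch.no_grad():
+    outs = neck(feats)
+# P3-P7 with strides 8-128: (80, 80), (40, 40), (20, 20), (10, 10), (5, 5)
+print([tuple(o.shape) for o in outs])
+```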
diff --git a/configs/nas_fpn/retinanet_crop640_r50_fpn_50e.py b/configs/nas_fpn/retinanet_crop640_r50_fpn_50e.py
new file mode 100644
index 0000000000000000000000000000000000000000..1587d876632710903ace05bd18e74b9f43f6b638
--- /dev/null
+++ b/configs/nas_fpn/retinanet_crop640_r50_fpn_50e.py
@@ -0,0 +1,149 @@
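+# Inputs are fixed-size 640x640 crops, so cudnn benchmark mode can cache the
+# fastest conv algorithms without re-tuning.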
+cudnn_benchmark = True
+# model settings
+norm_cfg = dict(type='BN', requires_grad=True)
+model = dict(
+    type='RetinaNet',
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs=True,
+        relu_before_extra_convs=True,
+        no_norm_on_lateral=True,
+        norm_cfg=norm_cfg,
+        num_outs=5),
+    bbox_head=dict(
+        type='RetinaSepBNHead',
+        num_classes=81,
+        num_ins=5,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        octave_base_scale=4,
+        scales_per_octave=3,
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[8, 16, 32, 64, 128],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        norm_cfg=norm_cfg,
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=1.0)))
+# training and testing settings
+train_cfg = dict(
+    assigner=dict(
+        type='MaxIoUAssigner',
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+        min_pos_iou=0,
+        ignore_iof_thr=-1),
+    allowed_border=-1,
+    pos_weight=-1,
+    debug=False)
+test_cfg = dict(
+    nms_pre=1000,
+    min_bbox_size=0,
+    score_thr=0.05,
+    nms=dict(type='nms', iou_thr=0.5),
+    max_per_img=100)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(
+        type='Resize',
+        img_scale=(640, 640),
+        ratio_range=(0.8, 1.2),
+        keep_ratio=True),
+    dict(type='RandomCrop', crop_size=(640, 640)),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=(640, 640)),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(640, 640),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=64),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    imgs_per_gpu=8,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline))
+# optimizer
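+# lr=0.08 assumes the large-batch setup below: 8 GPUs x imgs_per_gpu=8
+# (total batch size 64).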
+optimizer = dict(
+    type='SGD',
+    lr=0.08,
+    momentum=0.9,
+    weight_decay=0.0001,
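+    # exclude norm layers from weight decay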
+    paramwise_options=dict(norm_decay_mult=0))
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=1000,
+    warmup_ratio=0.1,
+    step=[30, 40])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 50
+device_ids = range(8)
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/retinanet_crop640_r50_fpn_50e'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/nas_fpn/retinanet_crop640_r50_nasfpn_50e.py b/configs/nas_fpn/retinanet_crop640_r50_nasfpn_50e.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5a37042bbf0fae1a00899b44e8496a86795f40a
--- /dev/null
+++ b/configs/nas_fpn/retinanet_crop640_r50_nasfpn_50e.py
@@ -0,0 +1,148 @@
+cudnn_benchmark = True
+# model settings
+norm_cfg = dict(type='BN', requires_grad=True)
+model = dict(
+    type='RetinaNet',
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch'),
+    neck=dict(
+        type='NASFPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5,
+        stack_times=7,
+        start_level=1,
+        add_extra_convs=True,
+        norm_cfg=norm_cfg),
+    bbox_head=dict(
+        type='RetinaSepBNHead',
+        num_classes=81,
+        num_ins=5,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        octave_base_scale=4,
+        scales_per_octave=3,
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[8, 16, 32, 64, 128],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        norm_cfg=norm_cfg,
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=1.0)))
+# training and testing settings
+train_cfg = dict(
+    assigner=dict(
+        type='MaxIoUAssigner',
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+        min_pos_iou=0,
+        ignore_iof_thr=-1),
+    allowed_border=-1,
+    pos_weight=-1,
+    debug=False)
+test_cfg = dict(
+    nms_pre=1000,
+    min_bbox_size=0,
+    score_thr=0.05,
+    nms=dict(type='nms', iou_thr=0.5),
+    max_per_img=100)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(
+        type='Resize',
+        img_scale=(640, 640),
+        ratio_range=(0.8, 1.2),
+        keep_ratio=True),
+    dict(type='RandomCrop', crop_size=(640, 640)),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=(640, 640)),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(640, 640),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
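+            # pad to a multiple of 128 so every level down to P7 (stride 128)
+            # keeps integer spatial sizes in the NAS-FPN merging cells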
+            dict(type='Pad', size_divisor=128),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    imgs_per_gpu=8,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline))
+# optimizer
+optimizer = dict(
+    type='SGD',
+    lr=0.08,
+    momentum=0.9,
+    weight_decay=0.0001,
+    paramwise_options=dict(norm_decay_mult=0))
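+# unlike the FPN baseline config, no gradient clipping is used here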
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=1000,
+    warmup_ratio=0.1,
+    step=[30, 40])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 50
+device_ids = range(8)
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/retinanet_crop640_r50_nasfpn_50e'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/docs/MODEL_ZOO.md b/docs/MODEL_ZOO.md
index 53b65ad5f06dfe69a910eb11e91e53a4a0c0ff3f..35314574ffd3ffaf3e500f68cc9fb8871763f0aa 100644
--- a/docs/MODEL_ZOO.md
+++ b/docs/MODEL_ZOO.md
@@ -277,6 +277,9 @@ Please refer to [Mask Scoring R-CNN](https://github.com/open-mmlab/mmdetection/b
 
 Please refer to [Rethinking ImageNet Pre-training](https://github.com/open-mmlab/mmdetection/blob/master/configs/scratch) for details.
 
+### NAS-FPN
+
+Please refer to [NAS-FPN](https://github.com/open-mmlab/mmdetection/blob/master/configs/nas_fpn) for details.
+
 ### Other datasets
 
 We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face).
diff --git a/mmdet/models/anchor_heads/__init__.py b/mmdet/models/anchor_heads/__init__.py
index 54db861fd9c122b8343aef12aa15f28d4ca7ee9c..c693c0fd209cf34eaf01a6ce5b28e40d883b6ea6 100644
--- a/mmdet/models/anchor_heads/__init__.py
+++ b/mmdet/models/anchor_heads/__init__.py
@@ -7,11 +7,12 @@ from .ga_rpn_head import GARPNHead
 from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
 from .reppoints_head import RepPointsHead
 from .retina_head import RetinaHead
+from .retina_sepbn_head import RetinaSepBNHead
 from .rpn_head import RPNHead
 from .ssd_head import SSDHead
 
 __all__ = [
     'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
-    'GARPNHead', 'RetinaHead', 'GARetinaHead', 'SSDHead', 'FCOSHead',
-    'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead'
+    'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead',
+    'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead'
 ]
diff --git a/mmdet/models/anchor_heads/retina_sepbn_head.py b/mmdet/models/anchor_heads/retina_sepbn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f076617908ee956e595abc26ba7b1921390ea74
--- /dev/null
+++ b/mmdet/models/anchor_heads/retina_sepbn_head.py
@@ -0,0 +1,105 @@
+import numpy as np
+import torch.nn as nn
+from mmcv.cnn import normal_init
+
+from ..registry import HEADS
+from ..utils import ConvModule, bias_init_with_prob
+from .anchor_head import AnchorHead
+
+
+@HEADS.register_module
+class RetinaSepBNHead(AnchorHead):
+    """"RetinaHead with separate BN.
+
+    In RetinaHead, conv/norm layers are shared across different FPN levels,
+    while in RetinaSepBNHead, conv layers are shared across different FPN
+    levels, but BN layers are separated.
+    """
+
+    def __init__(self,
+                 num_classes,
+                 num_ins,
+                 in_channels,
+                 stacked_convs=4,
+                 octave_base_scale=4,
+                 scales_per_octave=3,
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 **kwargs):
+        self.stacked_convs = stacked_convs
+        self.octave_base_scale = octave_base_scale
+        self.scales_per_octave = scales_per_octave
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.num_ins = num_ins
+        octave_scales = np.array(
+            [2**(i / scales_per_octave) for i in range(scales_per_octave)])
+        anchor_scales = octave_scales * octave_base_scale
+        super(RetinaSepBNHead, self).__init__(
+            num_classes, in_channels, anchor_scales=anchor_scales, **kwargs)
+
+    def _init_layers(self):
+        self.relu = nn.ReLU(inplace=True)
+        self.cls_convs = nn.ModuleList()
+        self.reg_convs = nn.ModuleList()
+        for _ in range(self.num_ins):
+            cls_convs = nn.ModuleList()
+            reg_convs = nn.ModuleList()
+            for i in range(self.stacked_convs):
+                chn = self.in_channels if i == 0 else self.feat_channels
+                cls_convs.append(
+                    ConvModule(
+                        chn,
+                        self.feat_channels,
+                        3,
+                        stride=1,
+                        padding=1,
+                        conv_cfg=self.conv_cfg,
+                        norm_cfg=self.norm_cfg))
+                reg_convs.append(
+                    ConvModule(
+                        chn,
+                        self.feat_channels,
+                        3,
+                        stride=1,
+                        padding=1,
+                        conv_cfg=self.conv_cfg,
+                        norm_cfg=self.norm_cfg))
+            self.cls_convs.append(cls_convs)
+            self.reg_convs.append(reg_convs)
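+        # share conv weights across pyramid levels (level 0 owns them);
+        # the BN layers inside each ConvModule remain per-level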
+        for i in range(self.stacked_convs):
+            for j in range(1, self.num_ins):
+                self.cls_convs[j][i].conv = self.cls_convs[0][i].conv
+                self.reg_convs[j][i].conv = self.reg_convs[0][i].conv
+        self.retina_cls = nn.Conv2d(
+            self.feat_channels,
+            self.num_anchors * self.cls_out_channels,
+            3,
+            padding=1)
+        self.retina_reg = nn.Conv2d(
+            self.feat_channels, self.num_anchors * 4, 3, padding=1)
+
+    def init_weights(self):
+        for m in self.cls_convs[0]:
+            normal_init(m.conv, std=0.01)
+        for m in self.reg_convs[0]:
+            normal_init(m.conv, std=0.01)
+        bias_cls = bias_init_with_prob(0.01)
+        normal_init(self.retina_cls, std=0.01, bias=bias_cls)
+        normal_init(self.retina_reg, std=0.01)
+
+    def forward(self, feats):
+        cls_scores = []
+        bbox_preds = []
+        for i, x in enumerate(feats):
+            cls_feat = x
+            reg_feat = x
+            for cls_conv in self.cls_convs[i]:
+                cls_feat = cls_conv(cls_feat)
+            for reg_conv in self.reg_convs[i]:
+                reg_feat = reg_conv(reg_feat)
+            cls_score = self.retina_cls(cls_feat)
+            bbox_pred = self.retina_reg(reg_feat)
+            cls_scores.append(cls_score)
+            bbox_preds.append(bbox_pred)
+        return cls_scores, bbox_preds
diff --git a/mmdet/models/necks/__init__.py b/mmdet/models/necks/__init__.py
index 6b26e5fee369f5d6578e8f4df9c4ff323d3510c6..fa5740443fa7d886878014ff55766cc2f60ac944 100644
--- a/mmdet/models/necks/__init__.py
+++ b/mmdet/models/necks/__init__.py
@@ -1,5 +1,6 @@
 from .bfp import BFP
 from .fpn import FPN
 from .hrfpn import HRFPN
+from .nas_fpn import NASFPN
 
-__all__ = ['FPN', 'BFP', 'HRFPN']
+__all__ = ['FPN', 'BFP', 'HRFPN', 'NASFPN']
diff --git a/mmdet/models/necks/nas_fpn.py b/mmdet/models/necks/nas_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0a6898374b40bdca97d1f57918a169465e4a00b
--- /dev/null
+++ b/mmdet/models/necks/nas_fpn.py
@@ -0,0 +1,186 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import caffe2_xavier_init
+
+from ..registry import NECKS
+from ..utils import ConvModule
+
+
+class MergingCell(nn.Module):
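+    """Base class for NAS-FPN merging cells.
+
+    Resizes two inputs to ``out_size``, merges them with a binary op defined
+    by the subclass, and optionally applies an (act, conv, norm) block.
+    """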
+
+    def __init__(self, channels=256, with_conv=True, norm_cfg=None):
+        super(MergingCell, self).__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            self.conv_out = ConvModule(
+                channels,
+                channels,
+                3,
+                padding=1,
+                norm_cfg=norm_cfg,
+                order=('act', 'conv', 'norm'))
+
+    def _binary_op(self, x1, x2):
+        raise NotImplementedError
+
+    def _resize(self, x, size):
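+        # upsample (nearest) when smaller than the target size; max-pool with
+        # a matching kernel/stride when an exact integer multiple larger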
+        if x.shape[-2:] == size:
+            return x
+        elif x.shape[-2:] < size:
+            return F.interpolate(x, size=size, mode='nearest')
+        else:
+            assert x.shape[-2] % size[-2] == 0 and x.shape[-1] % size[-1] == 0
+            kernel_size = x.shape[-1] // size[-1]
+            x = F.max_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
+            return x
+
+    def forward(self, x1, x2, out_size):
+        assert x1.shape[:2] == x2.shape[:2]
+        assert len(out_size) == 2
+
+        x1 = self._resize(x1, out_size)
+        x2 = self._resize(x2, out_size)
+
+        x = self._binary_op(x1, x2)
+        if self.with_conv:
+            x = self.conv_out(x)
+        return x
+
+
+class SumCell(MergingCell):
+
+    def _binary_op(self, x1, x2):
+        return x1 + x2
+
+
+class GPCell(MergingCell):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
+
+    def _binary_op(self, x1, x2):
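+        # global attention: gate x1 with sigmoid(GAP(x2)), then add x2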
+        x2_att = self.global_pool(x2).sigmoid()
+        return x2 + x2_att * x1
+
+
+@NECKS.register_module
+class NASFPN(nn.Module):
+    """NAS-FPN.
+
+    NAS-FPN: Learning Scalable Feature Pyramid Architecture for Object
+    Detection. (https://arxiv.org/abs/1904.07392)
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 num_outs,
+                 stack_times,
+                 start_level=0,
+                 end_level=-1,
+                 add_extra_convs=False,
+                 norm_cfg=None):
+        super(NASFPN, self).__init__()
+        assert isinstance(in_channels, list)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_ins = len(in_channels)  # num of input feature levels
+        self.num_outs = num_outs  # num of output feature levels
+        self.stack_times = stack_times
+        self.norm_cfg = norm_cfg
+
+        if end_level == -1:
+            self.backbone_end_level = self.num_ins
+            assert num_outs >= self.num_ins - start_level
+        else:
+            # if end_level < inputs, no extra level is allowed
+            self.backbone_end_level = end_level
+            assert end_level <= len(in_channels)
+            assert num_outs == end_level - start_level
+        self.start_level = start_level
+        self.end_level = end_level
+        self.add_extra_convs = add_extra_convs
+
+        # add lateral connections
+        self.lateral_convs = nn.ModuleList()
+        for i in range(self.start_level, self.backbone_end_level):
+            l_conv = ConvModule(
+                in_channels[i],
+                out_channels,
+                1,
+                norm_cfg=norm_cfg,
+                activation=None)
+            self.lateral_convs.append(l_conv)
+
+        # add extra downsample layers (1x1 conv + stride-2 max pooling)
+        extra_levels = num_outs - self.backbone_end_level + self.start_level
+        self.extra_downsamples = nn.ModuleList()
+        for i in range(extra_levels):
+            extra_conv = ConvModule(
+                out_channels,
+                out_channels,
+                1,
+                norm_cfg=norm_cfg,
+                activation=None)
+            self.extra_downsamples.append(
+                nn.Sequential(extra_conv, nn.MaxPool2d(2, 2)))
+
+        # add NAS FPN connections
+        self.fpn_stages = nn.ModuleList()
+        for _ in range(self.stack_times):
+            stage = nn.ModuleDict()
+            # gp(p6, p4) -> p4_1
+            stage['gp_64_4'] = GPCell(out_channels, norm_cfg=norm_cfg)
+            # sum(p4_1, p4) -> p4_2
+            stage['sum_44_4'] = SumCell(out_channels, norm_cfg=norm_cfg)
+            # sum(p4_2, p3) -> p3_out
+            stage['sum_43_3'] = SumCell(out_channels, norm_cfg=norm_cfg)
+            # sum(p3_out, p4_2) -> p4_out
+            stage['sum_34_4'] = SumCell(out_channels, norm_cfg=norm_cfg)
+            # sum(p5, gp(p4_out, p3_out)) -> p5_out
+            stage['gp_43_5'] = GPCell(with_conv=False)
+            stage['sum_55_5'] = SumCell(out_channels, norm_cfg=norm_cfg)
+            # sum(p7, gp(p5_out, p4_2)) -> p7_out
+            stage['gp_54_7'] = GPCell(with_conv=False)
+            stage['sum_77_7'] = SumCell(out_channels, norm_cfg=norm_cfg)
+            # gp(p7_out, p5_out) -> p6_out
+            stage['gp_75_6'] = GPCell(out_channels, norm_cfg=norm_cfg)
+            self.fpn_stages.append(stage)
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                caffe2_xavier_init(m)
+
+    def forward(self, inputs):
+        # build P3-P5
+        feats = [
+            lateral_conv(inputs[i + self.start_level])
+            for i, lateral_conv in enumerate(self.lateral_convs)
+        ]
+        # build P6-P7 on top of P5
+        for downsample in self.extra_downsamples:
+            feats.append(downsample(feats[-1]))
+
+        p3, p4, p5, p6, p7 = feats
+
+        for stage in self.fpn_stages:
+            # gp(p6, p4) -> p4_1
+            p4_1 = stage['gp_64_4'](p6, p4, out_size=p4.shape[-2:])
+            # sum(p4_1, p4) -> p4_2
+            p4_2 = stage['sum_44_4'](p4_1, p4, out_size=p4.shape[-2:])
+            # sum(p4_2, p3) -> p3_out
+            p3 = stage['sum_43_3'](p4_2, p3, out_size=p3.shape[-2:])
+            # sum(p3_out, p4_2) -> p4_out
+            p4 = stage['sum_34_4'](p3, p4_2, out_size=p4.shape[-2:])
+            # sum(p5, gp(p4_out, p3_out)) -> p5_out
+            p5_tmp = stage['gp_43_5'](p4, p3, out_size=p5.shape[-2:])
+            p5 = stage['sum_55_5'](p5, p5_tmp, out_size=p5.shape[-2:])
+            # sum(p7, gp(p5_out, p4_2)) -> p7_out
+            p7_tmp = stage['gp_54_7'](p5, p4_2, out_size=p7.shape[-2:])
+            p7 = stage['sum_77_7'](p7, p7_tmp, out_size=p7.shape[-2:])
+            # gp(p7_out, p5_out) -> p6_out
+            p6 = stage['gp_75_6'](p7, p5, out_size=p6.shape[-2:])
+
+        return p3, p4, p5, p6, p7