From b581e19ffc4c88390aab5471734e1772fca0c844 Mon Sep 17 00:00:00 2001
From: Jiangmiao Pang <pangjiangmiao@gmail.com>
Date: Sat, 25 May 2019 18:38:30 +0800
Subject: [PATCH] Code for CVPR 2019 Paper "Libra R-CNN: Towards Balanced
 Learning for Object Detection" (#687)

* Code for components of Libra R-CNN

* Configs and README for Libra R-CNN

* update bfp

* Update Model ZOO

* add comments in non-local

* fix shape

* update bfp

* update according to ck's comments

* update des

* update des

* fix loss

* fix according to ck's comments

* fix activation in non-local

* fix conv_mask in non-local

* fix conv_mask in non-local

* Remove outdated model urls

* refactoring for bfp

* change in_channels from list[int] to int

* refactoring for nonlocal

* udpate weight init of nonlocal

* minor fix

* update new model urls
---
 MODEL_ZOO.md                                  |   8 +-
 README.md                                     |   1 +
 configs/libra_rcnn/README.md                  |  26 +++
 .../libra_rcnn/libra_fast_rcnn_r50_fpn_1x.py  | 144 ++++++++++++++
 .../libra_faster_rcnn_r101_fpn_1x.py          | 185 +++++++++++++++++
 .../libra_faster_rcnn_r50_fpn_1x.py           | 185 +++++++++++++++++
 .../libra_faster_rcnn_x101_64x4d_fpn_1x.py    | 187 ++++++++++++++++++
 .../libra_rcnn/libra_retinanet_r50_fpn_1x.py  | 141 +++++++++++++
 .../bbox/samplers/iou_balanced_neg_sampler.py | 123 +++++++++---
 mmdet/core/loss/__init__.py                   |   6 +-
 mmdet/core/loss/losses.py                     |  63 ++++--
 mmdet/models/losses/__init__.py               |   6 +-
 mmdet/models/losses/balanced_l1_loss.py       |  31 +++
 mmdet/models/necks/__init__.py                |   3 +-
 mmdet/models/necks/bfp.py                     | 102 ++++++++++
 mmdet/models/plugins/__init__.py              |   3 +
 mmdet/models/plugins/non_local.py             | 114 +++++++++++
 17 files changed, 1281 insertions(+), 47 deletions(-)
 create mode 100644 configs/libra_rcnn/README.md
 create mode 100644 configs/libra_rcnn/libra_fast_rcnn_r50_fpn_1x.py
 create mode 100644 configs/libra_rcnn/libra_faster_rcnn_r101_fpn_1x.py
 create mode 100644 configs/libra_rcnn/libra_faster_rcnn_r50_fpn_1x.py
 create mode 100644 configs/libra_rcnn/libra_faster_rcnn_x101_64x4d_fpn_1x.py
 create mode 100644 configs/libra_rcnn/libra_retinanet_r50_fpn_1x.py
 create mode 100644 mmdet/models/losses/balanced_l1_loss.py
 create mode 100644 mmdet/models/necks/bfp.py
 create mode 100644 mmdet/models/plugins/__init__.py
 create mode 100644 mmdet/models/plugins/non_local.py

diff --git a/MODEL_ZOO.md b/MODEL_ZOO.md
index 41480bd..6b5ad86 100644
--- a/MODEL_ZOO.md
+++ b/MODEL_ZOO.md
@@ -214,11 +214,14 @@ Please refer to [Weight Standardization](configs/gn+ws/README.md) for details.
 
 Please refer to [Deformable Convolutional Networks](configs/dcn/README.md) for details.
 
+### Libra R-CNN
+
+Please refer to [Libra R-CNN](configs/libra_rcnn/README.md) for details.
+
 ### Guided Anchoring
 
 Please refer to [Guided Anchoring](configs/guided_anchoring/README.md) for details.
 
-
 ## Comparison with Detectron and maskrcnn-benchmark
 
 We compare mmdetection with [Detectron](https://github.com/facebookresearch/Detectron)
@@ -454,6 +457,3 @@ and the main advantage is PyTorch itself. We also perform some memory optimizati
 
 Note that Caffe2 and PyTorch have different apis to obtain memory usage with different implementations.
 For all codebases, `nvidia-smi` shows a larger memory usage than the reported number in the above table.
-
-
-
diff --git a/README.md b/README.md
index 6dcd83d..1f9cdd3 100644
--- a/README.md
+++ b/README.md
@@ -93,6 +93,7 @@ Results and models are available in the [Model zoo](MODEL_ZOO.md).
 | RetinaNet          | ✓        | ✓        | ☐        | ✗        |
 | Hybrid Task Cascade| ✓        | ✓        | ☐        | ✗        |
 | FCOS               | ✓        | ✓        | ☐        | ✗        |
+| Libra R-CNN        | ✓        | ✓        | ☐        | ✗        |
 
 Other features
 - [x] DCNv2
diff --git a/configs/libra_rcnn/README.md b/configs/libra_rcnn/README.md
new file mode 100644
index 0000000..e70d61c
--- /dev/null
+++ b/configs/libra_rcnn/README.md
@@ -0,0 +1,26 @@
+# Libra R-CNN: Towards Balanced Learning for Object Detection
+
+## Introduction
+
+We provide config files to reproduce the results in the CVPR 2019 paper [Libra R-CNN](https://arxiv.org/pdf/1904.02701.pdf).
+
+```
+@inproceedings{pang2019libra,
+  title={Libra R-CNN: Towards Balanced Learning for Object Detection},
+  author={Pang, Jiangmiao and Chen, Kai and Shi, Jianping and Feng, Huajun and Ouyang, Wanli and Dahua Lin},
+  booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+  year={2019}
+}
+```
+
+## Results and models
+
+The results on COCO 2017val are shown in the below table. (results on test-dev are usually slightly higher than val)
+
+| Architecture | Backbone  | Style   | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
+|:---------:|:-------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:-------:|:--------:|
+| Faster R-CNN | R-50-FPN        | pytorch | 1x | 4.2  | 0.375 | 12.0 | 38.6 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/libra_rcnn/libra_faster_rcnn_r50_fpn_1x_20190525-c8c06833.pth) |
+| Fast R-CNN   | R-50-FPN        | pytorch | 1x | 3.7  | 0.272 | 16.3 | 38.5 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/libra_rcnn/libra_fast_rcnn_r50_fpn_1x_20190525-a43f88b5.pth) |
+| Faster R-CNN | R-101-FPN       | pytorch | 1x | 6.0  | 0.495 | 10.4 | 40.3 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/libra_rcnn/libra_faster_rcnn_r101_fpn_1x_20190525-94e94051.pth) |
+| Faster R-CNN | X-101-64x4d-FPN | pytorch | 1x | 10.1 | 1.050 | 6.8  | 42.7 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/libra_rcnn/libra_faster_rcnn_x101_64x4d_fpn_1x_20190525-359c134a.pth) |
+| RetinaNet    | R-50-FPN        | pytorch | 1x | 3.7  | 0.328 | 11.8 | 37.7 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/libra_rcnn/libra_retinanet_r50_fpn_1x_20190525-ead2a6bb.pth) |
diff --git a/configs/libra_rcnn/libra_fast_rcnn_r50_fpn_1x.py b/configs/libra_rcnn/libra_fast_rcnn_r50_fpn_1x.py
new file mode 100644
index 0000000..30189a4
--- /dev/null
+++ b/configs/libra_rcnn/libra_fast_rcnn_r50_fpn_1x.py
@@ -0,0 +1,144 @@
+# model settings
+model = dict(
+    type='FastRCNN',
+    pretrained='modelzoo://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=[
+        dict(
+            type='FPN',
+            in_channels=[256, 512, 1024, 2048],
+            out_channels=256,
+            num_outs=5),
+        dict(
+            type='BFP',
+            in_channels=256,
+            num_levels=5,
+            refine_level=2,
+            refine_type='non_local')
+    ],
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(
+            type='BalancedL1Loss',
+            alpha=0.5,
+            gamma=1.5,
+            beta=1.0,
+            loss_weight=1.0)))
+# model training and testing settings
+train_cfg = dict(
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='CombinedSampler',
+            num=512,
+            pos_fraction=0.25,
+            add_gt_as_proposals=True,
+            pos_sampler=dict(type='InstanceBalancedPosSampler'),
+            neg_sampler=dict(
+                type='IoUBalancedNegSampler',
+                floor=-1,
+                floor_thr=0,
+                num_bins=3)),
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rcnn=dict(
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100))
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=0,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        proposal_file=data_root +
+        'libra_proposals/rpn_r50_fpn_1x_train2017.pkl',
+        flip_ratio=0.5,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        proposal_file=data_root + 'libra_proposals/rpn_r50_fpn_1x_val2017.pkl',
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        proposal_file=data_root + 'libra_proposals/rpn_r50_fpn_1x_val2017.pkl',
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/libra_fast_rcnn_r50_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/libra_rcnn/libra_faster_rcnn_r101_fpn_1x.py b/configs/libra_rcnn/libra_faster_rcnn_r101_fpn_1x.py
new file mode 100644
index 0000000..3b7b2fb
--- /dev/null
+++ b/configs/libra_rcnn/libra_faster_rcnn_r101_fpn_1x.py
@@ -0,0 +1,185 @@
+# model settings
+model = dict(
+    type='FasterRCNN',
+    pretrained='modelzoo://resnet101',
+    backbone=dict(
+        type='ResNet',
+        depth=101,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=[
+        dict(
+            type='FPN',
+            in_channels=[256, 512, 1024, 2048],
+            out_channels=256,
+            num_outs=5),
+        dict(
+            type='BFP',
+            in_channels=256,
+            num_levels=5,
+            refine_level=2,
+            refine_type='non_local')
+    ],
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(
+            type='BalancedL1Loss',
+            alpha=0.5,
+            gamma=1.5,
+            beta=1.0,
+            loss_weight=1.0)))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=5,
+            add_gt_as_proposals=False),
+        allowed_border=-1,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='CombinedSampler',
+            num=512,
+            pos_fraction=0.25,
+            add_gt_as_proposals=True,
+            pos_sampler=dict(type='InstanceBalancedPosSampler'),
+            neg_sampler=dict(
+                type='IoUBalancedNegSampler',
+                floor=-1,
+                floor_thr=0,
+                num_bins=3)),
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
+    # soft-nms is also supported for rcnn testing
+    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
+)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/libra_faster_rcnn_r101_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/libra_rcnn/libra_faster_rcnn_r50_fpn_1x.py b/configs/libra_rcnn/libra_faster_rcnn_r50_fpn_1x.py
new file mode 100644
index 0000000..cec4944
--- /dev/null
+++ b/configs/libra_rcnn/libra_faster_rcnn_r50_fpn_1x.py
@@ -0,0 +1,185 @@
+# model settings
+model = dict(
+    type='FasterRCNN',
+    pretrained='modelzoo://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=[
+        dict(
+            type='FPN',
+            in_channels=[256, 512, 1024, 2048],
+            out_channels=256,
+            num_outs=5),
+        dict(
+            type='BFP',
+            in_channels=256,
+            num_levels=5,
+            refine_level=2,
+            refine_type='non_local')
+    ],
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(
+            type='BalancedL1Loss',
+            alpha=0.5,
+            gamma=1.5,
+            beta=1.0,
+            loss_weight=1.0)))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=5,
+            add_gt_as_proposals=False),
+        allowed_border=-1,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='CombinedSampler',
+            num=512,
+            pos_fraction=0.25,
+            add_gt_as_proposals=True,
+            pos_sampler=dict(type='InstanceBalancedPosSampler'),
+            neg_sampler=dict(
+                type='IoUBalancedNegSampler',
+                floor=-1,
+                floor_thr=0,
+                num_bins=3)),
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
+    # soft-nms is also supported for rcnn testing
+    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
+)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/libra_faster_rcnn_r50_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/libra_rcnn/libra_faster_rcnn_x101_64x4d_fpn_1x.py b/configs/libra_rcnn/libra_faster_rcnn_x101_64x4d_fpn_1x.py
new file mode 100644
index 0000000..3b4a8e8
--- /dev/null
+++ b/configs/libra_rcnn/libra_faster_rcnn_x101_64x4d_fpn_1x.py
@@ -0,0 +1,187 @@
+# model settings
+model = dict(
+    type='FasterRCNN',
+    pretrained='open-mmlab://resnext101_64x4d',
+    backbone=dict(
+        type='ResNeXt',
+        depth=101,
+        groups=64,
+        base_width=4,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=[
+        dict(
+            type='FPN',
+            in_channels=[256, 512, 1024, 2048],
+            out_channels=256,
+            num_outs=5),
+        dict(
+            type='BFP',
+            in_channels=256,
+            num_levels=5,
+            refine_level=2,
+            refine_type='non_local')
+    ],
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(
+            type='BalancedL1Loss',
+            alpha=0.5,
+            gamma=1.5,
+            beta=1.0,
+            loss_weight=1.0)))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=5,
+            add_gt_as_proposals=False),
+        allowed_border=-1,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='CombinedSampler',
+            num=512,
+            pos_fraction=0.25,
+            add_gt_as_proposals=True,
+            pos_sampler=dict(type='InstanceBalancedPosSampler'),
+            neg_sampler=dict(
+                type='IoUBalancedNegSampler',
+                floor=-1,
+                floor_thr=0,
+                num_bins=3)),
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
+    # soft-nms is also supported for rcnn testing
+    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
+)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/libra_faster_rcnn_x101_64x4d_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/libra_rcnn/libra_retinanet_r50_fpn_1x.py b/configs/libra_rcnn/libra_retinanet_r50_fpn_1x.py
new file mode 100644
index 0000000..70aff37
--- /dev/null
+++ b/configs/libra_rcnn/libra_retinanet_r50_fpn_1x.py
@@ -0,0 +1,141 @@
+# model settings
+model = dict(
+    type='RetinaNet',
+    pretrained='modelzoo://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=[
+        dict(
+            type='FPN',
+            in_channels=[256, 512, 1024, 2048],
+            out_channels=256,
+            start_level=1,
+            extra_convs_on_inputs=True,
+            add_extra_convs=True,
+            num_outs=5),
+        dict(
+            type='BFP',
+            in_channels=256,
+            num_levels=5,
+            refine_level=1,
+            refine_type='non_local')
+    ],
+    bbox_head=dict(
+        type='RetinaHead',
+        num_classes=81,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        octave_base_scale=4,
+        scales_per_octave=3,
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[8, 16, 32, 64, 128],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(
+            type='BalancedL1Loss',
+            alpha=0.5,
+            gamma=1.5,
+            beta=0.11,
+            loss_weight=1.0)))
+# training and testing settings
+train_cfg = dict(
+    assigner=dict(
+        type='MaxIoUAssigner',
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.4,
+        min_pos_iou=0,
+        ignore_iof_thr=-1),
+    smoothl1_beta=0.11,
+    gamma=2.0,
+    alpha=0.25,
+    allowed_border=-1,
+    pos_weight=-1,
+    debug=False)
+test_cfg = dict(
+    nms_pre=1000,
+    min_bbox_size=0,
+    score_thr=0.05,
+    nms=dict(type='nms', iou_thr=0.5),
+    max_per_img=100)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=False,
+        with_crowd=False,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=False,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+device_ids = range(8)
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/libra_retinanet_r50_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py b/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
index 82537ee..62431d6 100644
--- a/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
+++ b/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
@@ -5,19 +5,72 @@ from .random_sampler import RandomSampler
 
 
 class IoUBalancedNegSampler(RandomSampler):
+    """IoU Balanced Sampling
+
+    arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)
+
+    Sampling proposals according to their IoU. `floor_fraction` of needed RoIs
+    are sampled from proposals whose IoU are lower than `floor_thr` randomly.
+    The others are sampled from proposals whose IoU are higher than
+    `floor_thr`. These proposals are sampled from some bins evenly, which are
+    split by `num_bins` via IoU evenly.
+
+    Args:
+        num (int): number of proposals.
+        pos_fraction (float): fraction of positive proposals.
+        floor_thr (float): threshold (minimum) IoU for IoU balanced sampling,
+            set to -1 if all using IoU balanced sampling.
+        floor_fraction (float): sampling fraction of proposals under floor_thr.
+        num_bins (int): number of bins in IoU balanced sampling.
+    """
 
     def __init__(self,
                  num,
                  pos_fraction,
-                 hard_thr=0.1,
-                 hard_fraction=0.5,
+                 floor_thr=-1,
+                 floor_fraction=0,
+                 num_bins=3,
                  **kwargs):
         super(IoUBalancedNegSampler, self).__init__(num, pos_fraction,
                                                     **kwargs)
-        assert hard_thr > 0
-        assert 0 < hard_fraction < 1
-        self.hard_thr = hard_thr
-        self.hard_fraction = hard_fraction
+        assert floor_thr >= 0 or floor_thr == -1
+        assert 0 <= floor_fraction <= 1
+        assert num_bins >= 1
+
+        self.floor_thr = floor_thr
+        self.floor_fraction = floor_fraction
+        self.num_bins = num_bins
+
+    def sample_via_interval(self, max_overlaps, full_set, num_expected):
+        max_iou = max_overlaps.max()
+        iou_interval = (max_iou - self.floor_thr) / self.num_bins
+        per_num_expected = int(num_expected / self.num_bins)
+
+        sampled_inds = []
+        for i in range(self.num_bins):
+            start_iou = self.floor_thr + i * iou_interval
+            end_iou = self.floor_thr + (i + 1) * iou_interval
+            tmp_set = set(
+                np.where(
+                    np.logical_and(max_overlaps >= start_iou,
+                                   max_overlaps < end_iou))[0])
+            tmp_inds = list(tmp_set & full_set)
+            if len(tmp_inds) > per_num_expected:
+                tmp_sampled_set = self.random_choice(tmp_inds,
+                                                     per_num_expected)
+            else:
+                tmp_sampled_set = np.array(tmp_inds, dtype=np.int)
+            sampled_inds.append(tmp_sampled_set)
+
+        sampled_inds = np.concatenate(sampled_inds)
+        if len(sampled_inds) < num_expected:
+            num_extra = num_expected - len(sampled_inds)
+            extra_inds = np.array(list(full_set - set(sampled_inds)))
+            if len(extra_inds) > num_extra:
+                extra_inds = self.random_choice(extra_inds, num_extra)
+            sampled_inds = np.concatenate([sampled_inds, extra_inds])
+
+        return sampled_inds
 
     def _sample_neg(self, assign_result, num_expected, **kwargs):
         neg_inds = torch.nonzero(assign_result.gt_inds == 0)
@@ -29,28 +82,46 @@ class IoUBalancedNegSampler(RandomSampler):
             max_overlaps = assign_result.max_overlaps.cpu().numpy()
             # balance sampling for negative samples
             neg_set = set(neg_inds.cpu().numpy())
-            easy_set = set(
-                np.where(
-                    np.logical_and(max_overlaps >= 0,
-                                   max_overlaps < self.hard_thr))[0])
-            hard_set = set(np.where(max_overlaps >= self.hard_thr)[0])
-            easy_neg_inds = list(easy_set & neg_set)
-            hard_neg_inds = list(hard_set & neg_set)
-
-            num_expected_hard = int(num_expected * self.hard_fraction)
-            if len(hard_neg_inds) > num_expected_hard:
-                sampled_hard_inds = self.random_choice(hard_neg_inds,
-                                                       num_expected_hard)
+
+            if self.floor_thr > 0:
+                floor_set = set(
+                    np.where(
+                        np.logical_and(max_overlaps >= 0,
+                                       max_overlaps < self.floor_thr))[0])
+                iou_sampling_set = set(
+                    np.where(max_overlaps >= self.floor_thr)[0])
+            elif self.floor_thr == 0:
+                floor_set = set(np.where(max_overlaps == 0)[0])
+                iou_sampling_set = set(
+                    np.where(max_overlaps > self.floor_thr)[0])
+            else:
+                floor_set = set()
+                iou_sampling_set = set(
+                    np.where(max_overlaps > self.floor_thr)[0])
+
+            floor_neg_inds = list(floor_set & neg_set)
+            iou_sampling_neg_inds = list(iou_sampling_set & neg_set)
+            num_expected_iou_sampling = int(num_expected *
+                                            (1 - self.floor_fraction))
+            if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
+                if self.num_bins >= 2:
+                    iou_sampled_inds = self.sample_via_interval(
+                        max_overlaps, set(iou_sampling_neg_inds),
+                        num_expected_iou_sampling)
+                else:
+                    iou_sampled_inds = self.random_choice(
+                        iou_sampling_neg_inds, num_expected_iou_sampling)
             else:
-                sampled_hard_inds = np.array(hard_neg_inds, dtype=np.int)
-            num_expected_easy = num_expected - len(sampled_hard_inds)
-            if len(easy_neg_inds) > num_expected_easy:
-                sampled_easy_inds = self.random_choice(easy_neg_inds,
-                                                       num_expected_easy)
+                iou_sampled_inds = np.array(
+                    iou_sampling_neg_inds, dtype=np.int)
+            num_expected_floor = num_expected - len(iou_sampled_inds)
+            if len(floor_neg_inds) > num_expected_floor:
+                sampled_floor_inds = self.random_choice(
+                    floor_neg_inds, num_expected_floor)
             else:
-                sampled_easy_inds = np.array(easy_neg_inds, dtype=np.int)
-            sampled_inds = np.concatenate((sampled_easy_inds,
-                                           sampled_hard_inds))
+                sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int)
+            sampled_inds = np.concatenate(
+                (sampled_floor_inds, iou_sampled_inds))
             if len(sampled_inds) < num_expected:
                 num_extra = num_expected - len(sampled_inds)
                 extra_inds = np.array(list(neg_set - set(sampled_inds)))
diff --git a/mmdet/core/loss/__init__.py b/mmdet/core/loss/__init__.py
index c73b221..ad7b21f 100644
--- a/mmdet/core/loss/__init__.py
+++ b/mmdet/core/loss/__init__.py
@@ -2,12 +2,14 @@ from .losses import (weighted_nll_loss, weighted_cross_entropy,
                      weighted_binary_cross_entropy, sigmoid_focal_loss,
                      py_sigmoid_focal_loss, weighted_sigmoid_focal_loss,
                      mask_cross_entropy, smooth_l1_loss, weighted_smoothl1,
-                     bounded_iou_loss, weighted_iou_loss, iou_loss, accuracy)
+                     balanced_l1_loss, weighted_balanced_l1_loss, iou_loss,
+                     bounded_iou_loss, weighted_iou_loss, accuracy)
 
 __all__ = [
     'weighted_nll_loss', 'weighted_cross_entropy',
     'weighted_binary_cross_entropy', 'sigmoid_focal_loss',
     'py_sigmoid_focal_loss', 'weighted_sigmoid_focal_loss',
     'mask_cross_entropy', 'smooth_l1_loss', 'weighted_smoothl1',
-    'bounded_iou_loss', 'weighted_iou_loss', 'iou_loss', 'accuracy'
+    'balanced_l1_loss', 'weighted_balanced_l1_loss', 'bounded_iou_loss',
+    'weighted_iou_loss', 'iou_loss', 'accuracy'
 ]
diff --git a/mmdet/core/loss/losses.py b/mmdet/core/loss/losses.py
index 6bb0954..8db7118 100644
--- a/mmdet/core/loss/losses.py
+++ b/mmdet/core/loss/losses.py
@@ -1,4 +1,5 @@
 # TODO merge naive and weighted loss.
+import numpy as np
 import torch
 import torch.nn.functional as F
 
@@ -44,8 +45,8 @@ def py_sigmoid_focal_loss(pred,
     pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
     weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
     weight = weight * pt.pow(gamma)
-    loss = F.binary_cross_entropy_with_logits(pred, target,
-                                              reduction='none') * weight
+    loss = F.binary_cross_entropy_with_logits(
+        pred, target, reduction='none') * weight
     reduction_enum = F._Reduction.get_enum(reduction)
     # none: 0, mean:1, sum: 2
     if reduction_enum == 0:
@@ -74,9 +75,8 @@ def mask_cross_entropy(pred, target, label):
     num_rois = pred.size()[0]
     inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
     pred_slice = pred[inds, label].squeeze(1)
-    return F.binary_cross_entropy_with_logits(pred_slice,
-                                              target,
-                                              reduction='mean')[None]
+    return F.binary_cross_entropy_with_logits(
+        pred_slice, target, reduction='mean')[None]
 
 
 def smooth_l1_loss(pred, target, beta=1.0, reduction='mean'):
@@ -102,6 +102,47 @@ def weighted_smoothl1(pred, target, weight, beta=1.0, avg_factor=None):
     return torch.sum(loss * weight)[None] / avg_factor
 
 
+def balanced_l1_loss(pred,
+                     target,
+                     beta=1.0,
+                     alpha=0.5,
+                     gamma=1.5,
+                     reduction='none'):
+    assert beta > 0
+    assert pred.size() == target.size() and target.numel() > 0
+
+    diff = torch.abs(pred - target)
+    b = np.e**(gamma / alpha) - 1
+    loss = torch.where(
+        diff < beta, alpha / b *
+        (b * diff + 1) * torch.log(b * diff / beta + 1) - alpha * diff,
+        gamma * diff + gamma / b - alpha * beta)
+
+    reduction_enum = F._Reduction.get_enum(reduction)
+    # none: 0, elementwise_mean:1, sum: 2
+    if reduction_enum == 0:
+        return loss
+    elif reduction_enum == 1:
+        return loss.sum() / pred.numel()
+    elif reduction_enum == 2:
+        return loss.sum()
+
+    return loss
+
+
+def weighted_balanced_l1_loss(pred,
+                              target,
+                              weight,
+                              beta=1.0,
+                              alpha=0.5,
+                              gamma=1.5,
+                              avg_factor=None):
+    if avg_factor is None:
+        avg_factor = torch.sum(weight > 0).float().item() / 4 + 1e-6
+    loss = balanced_l1_loss(pred, target, beta, alpha, gamma, reduction='none')
+    return torch.sum(loss * weight)[None] / avg_factor
+
+
 def bounded_iou_loss(pred, target, beta=0.2, eps=1e-3, reduction='mean'):
     """Improving Object Localization with Fitness NMS and Bounded IoU Loss,
     https://arxiv.org/abs/1711.00164.
@@ -170,11 +211,8 @@ def weighted_iou_loss(pred,
         return (pred * weight).sum()[None] / avg_factor
 
     if style == 'bounded':
-        loss = bounded_iou_loss(pred[inds],
-                                target[inds],
-                                beta=beta,
-                                eps=eps,
-                                reduction='sum')
+        loss = bounded_iou_loss(
+            pred[inds], target[inds], beta=beta, eps=eps, reduction='sum')
     else:
         loss = iou_loss(pred[inds], target[inds], reduction='sum')
     loss = loss[None] / avg_factor
@@ -205,9 +243,8 @@ def _expand_binary_labels(labels, label_weights, label_channels):
     inds = torch.nonzero(labels >= 1).squeeze()
     if inds.numel() > 0:
         bin_labels[inds, labels[inds] - 1] = 1
-    bin_label_weights = label_weights.view(-1,
-                                           1).expand(label_weights.size(0),
-                                                     label_channels)
+    bin_label_weights = label_weights.view(-1, 1).expand(
+        label_weights.size(0), label_channels)
     return bin_labels, bin_label_weights
 
 
diff --git a/mmdet/models/losses/__init__.py b/mmdet/models/losses/__init__.py
index 3b00245..a2d7b4a 100644
--- a/mmdet/models/losses/__init__.py
+++ b/mmdet/models/losses/__init__.py
@@ -1,6 +1,10 @@
 from .cross_entropy_loss import CrossEntropyLoss
 from .focal_loss import FocalLoss
 from .smooth_l1_loss import SmoothL1Loss
+from .balanced_l1_loss import BalancedL1Loss
 from .iou_loss import IoULoss
 
-__all__ = ['CrossEntropyLoss', 'FocalLoss', 'SmoothL1Loss', 'IoULoss']
+__all__ = [
+    'CrossEntropyLoss', 'FocalLoss', 'SmoothL1Loss', 'BalancedL1Loss',
+    'IoULoss'
+]
diff --git a/mmdet/models/losses/balanced_l1_loss.py b/mmdet/models/losses/balanced_l1_loss.py
new file mode 100644
index 0000000..7511e26
--- /dev/null
+++ b/mmdet/models/losses/balanced_l1_loss.py
@@ -0,0 +1,31 @@
+import torch.nn as nn
+from mmdet.core import weighted_balanced_l1_loss
+
+from ..registry import LOSSES
+
+
+@LOSSES.register_module
+class BalancedL1Loss(nn.Module):
+    """Balanced L1 Loss
+
+    arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)
+    """
+
+    def __init__(self, alpha=0.5, gamma=1.5, beta=1.0, loss_weight=1.0):
+        super(BalancedL1Loss, self).__init__()
+        self.alpha = alpha
+        self.gamma = gamma
+        self.beta = beta
+        self.loss_weight = loss_weight
+
+    def forward(self, pred, target, weight, *args, **kwargs):
+        loss_bbox = self.loss_weight * weighted_balanced_l1_loss(
+            pred,
+            target,
+            weight,
+            alpha=self.alpha,
+            gamma=self.gamma,
+            beta=self.beta,
+            *args,
+            **kwargs)
+        return loss_bbox
diff --git a/mmdet/models/necks/__init__.py b/mmdet/models/necks/__init__.py
index f88b47c..aa56c42 100644
--- a/mmdet/models/necks/__init__.py
+++ b/mmdet/models/necks/__init__.py
@@ -1,4 +1,5 @@
 from .fpn import FPN
+from .bfp import BFP
 from .hrfpn import HRFPN
 
-__all__ = ['FPN', 'HRFPN']
+__all__ = ['FPN', 'BFP', 'HRFPN']
diff --git a/mmdet/models/necks/bfp.py b/mmdet/models/necks/bfp.py
new file mode 100644
index 0000000..03aee10
--- /dev/null
+++ b/mmdet/models/necks/bfp.py
@@ -0,0 +1,102 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import xavier_init
+
+from ..plugins import NonLocal2D
+from ..registry import NECKS
+from ..utils import ConvModule
+
+
+@NECKS.register_module
+class BFP(nn.Module):
+    """BFP (Balanced Feature Pyrmamids)
+
+    BFP takes multi-level features as inputs and gather them into a single one,
+    then refine the gathered feature and scatter the refined results to
+    multi-level features. This module is used in Libra R-CNN (CVPR 2019), see
+    https://arxiv.org/pdf/1904.02701.pdf for details.
+
+    Args:
+        in_channels (int): Number of input channels (feature maps of all levels
+            should have the same channels).
+        num_levels (int): Number of input feature levels.
+        conv_cfg (dict): The config dict for convolution layers.
+        norm_cfg (dict): The config dict for normalization layers.
+        refine_level (int): Index of integration and refine level of BSF in
+            multi-level features from bottom to top.
+        refine_type (str): Type of the refine op, currently support
+            [None, 'conv', 'non_local'].
+    """
+
+    def __init__(self,
+                 in_channels,
+                 num_levels,
+                 refine_level=2,
+                 refine_type=None,
+                 conv_cfg=None,
+                 norm_cfg=None):
+        super(BFP, self).__init__()
+        assert refine_type in [None, 'conv', 'non_local']
+
+        self.in_channels = in_channels
+        self.num_levels = num_levels
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+
+        self.refine_level = refine_level
+        self.refine_type = refine_type
+        assert 0 <= self.refine_level < self.num_levels
+
+        if self.refine_type == 'conv':
+            self.refine = ConvModule(
+                self.in_channels,
+                self.in_channels,
+                3,
+                padding=1,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg)
+        elif self.refine_type == 'non_local':
+            self.refine = NonLocal2D(
+                self.in_channels,
+                reduction=1,
+                use_scale=False,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg)
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                xavier_init(m, distribution='uniform')
+
+    def forward(self, inputs):
+        assert len(inputs) == self.num_levels
+
+        # step 1: gather multi-level features by resize and average
+        feats = []
+        gather_size = inputs[self.refine_level].size()[2:]
+        for i in range(self.num_levels):
+            if i < self.refine_level:
+                gathered = F.adaptive_max_pool2d(
+                    inputs[i], output_size=gather_size)
+            else:
+                gathered = F.interpolate(
+                    inputs[i], size=gather_size, mode='nearest')
+            feats.append(gathered)
+
+        bsf = sum(feats) / len(feats)
+
+        # step 2: refine gathered features
+        if self.refine_type is not None:
+            bsf = self.refine(bsf)
+
+        # step 3: scatter refined features to multi-levels by a residual path
+        outs = []
+        for i in range(self.num_levels):
+            out_size = inputs[i].size()[2:]
+            if i < self.refine_level:
+                residual = F.interpolate(bsf, size=out_size, mode='nearest')
+            else:
+                residual = F.adaptive_max_pool2d(bsf, output_size=out_size)
+            outs.append(residual + inputs[i])
+
+        return tuple(outs)
diff --git a/mmdet/models/plugins/__init__.py b/mmdet/models/plugins/__init__.py
new file mode 100644
index 0000000..87744df
--- /dev/null
+++ b/mmdet/models/plugins/__init__.py
@@ -0,0 +1,3 @@
+from .non_local import NonLocal2D
+
+__all__ = ['NonLocal2D']
diff --git a/mmdet/models/plugins/non_local.py b/mmdet/models/plugins/non_local.py
new file mode 100644
index 0000000..cbec7a4
--- /dev/null
+++ b/mmdet/models/plugins/non_local.py
@@ -0,0 +1,114 @@
+import torch
+import torch.nn as nn
+from mmcv.cnn import constant_init, normal_init
+
+from ..utils import ConvModule
+
+
+class NonLocal2D(nn.Module):
+    """Non-local module.
+
+    See https://arxiv.org/abs/1711.07971 for details.
+
+    Args:
+        in_channels (int): Channels of the input feature map.
+        reduction (int): Channel reduction ratio.
+        use_scale (bool): Whether to scale pairwise_weight by 1/inter_channels.
+        conv_cfg (dict): The config dict for convolution layers.
+            (only applicable to conv_out)
+        norm_cfg (dict): The config dict for normalization layers.
+            (only applicable to conv_out)
+        mode (str): Options are `embedded_gaussian` and `dot_product`.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 reduction=2,
+                 use_scale=True,
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 mode='embedded_gaussian'):
+        super(NonLocal2D, self).__init__()
+        self.in_channels = in_channels
+        self.reduction = reduction
+        self.use_scale = use_scale
+        self.inter_channels = in_channels // reduction
+        self.mode = mode
+        assert mode in ['embedded_gaussian', 'dot_product']
+
+        # g, theta, phi are actually `nn.Conv2d`. Here we use ConvModule for
+        # potential usage.
+        self.g = ConvModule(
+            self.in_channels,
+            self.inter_channels,
+            kernel_size=1,
+            activation=None)
+        self.theta = ConvModule(
+            self.in_channels,
+            self.inter_channels,
+            kernel_size=1,
+            activation=None)
+        self.phi = ConvModule(
+            self.in_channels,
+            self.inter_channels,
+            kernel_size=1,
+            activation=None)
+        self.conv_out = ConvModule(
+            self.inter_channels,
+            self.in_channels,
+            kernel_size=1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            activation=None)
+
+        self.init_weights()
+
+    def init_weights(self, std=0.01, zeros_init=True):
+        for m in [self.g, self.theta, self.phi]:
+            normal_init(m.conv, std=std)
+        if zeros_init:
+            constant_init(self.conv_out.conv, 0)
+        else:
+            normal_init(self.conv_out.conv, std=std)
+
+    def embedded_gaussian(self, theta_x, phi_x):
+        # pairwise_weight: [N, HxW, HxW]
+        pairwise_weight = torch.matmul(theta_x, phi_x)
+        if self.use_scale:
+            # theta_x.shape[-1] is `self.inter_channels`
+            pairwise_weight /= theta_x.shape[-1]**-0.5
+        pairwise_weight = pairwise_weight.softmax(dim=-1)
+        return pairwise_weight
+
+    def dot_product(self, theta_x, phi_x):
+        # pairwise_weight: [N, HxW, HxW]
+        pairwise_weight = torch.matmul(theta_x, phi_x)
+        pairwise_weight /= pairwise_weight.shape[-1]
+        return pairwise_weight
+
+    def forward(self, x):
+        n, _, h, w = x.shape
+
+        # g_x: [N, HxW, C]
+        g_x = self.g(x).view(n, self.inter_channels, -1)
+        g_x = g_x.permute(0, 2, 1)
+
+        # theta_x: [N, HxW, C]
+        theta_x = self.theta(x).view(n, self.inter_channels, -1)
+        theta_x = theta_x.permute(0, 2, 1)
+
+        # phi_x: [N, C, HxW]
+        phi_x = self.phi(x).view(n, self.inter_channels, -1)
+
+        pairwise_func = getattr(self, self.mode)
+        # pairwise_weight: [N, HxW, HxW]
+        pairwise_weight = pairwise_func(theta_x, phi_x)
+
+        # y: [N, HxW, C]
+        y = torch.matmul(pairwise_weight, g_x)
+        # y: [N, C, H, W]
+        y = y.permute(0, 2, 1).reshape(n, self.inter_channels, h, w)
+
+        output = x + self.conv_out(y)
+
+        return output
-- 
GitLab