From d1c71db791e2e8279e1a3214e8217a36495b22a0 Mon Sep 17 00:00:00 2001
From: Cao Yuhang <yhcao6@gmail.com>
Date: Wed, 27 Nov 2019 22:40:07 +0800
Subject: [PATCH] Reimplement "FreeAnchor: Learning to Match Anchors for Visual
 Object Detection" (#1391)

* add FreeAnchorRetinanet

* fix std

* refactor code

* add cfg

* fix bug in cfg

* add annotation and clean

* add readme

* fix isort error

* replace clip with torch.clamp

* fix init typo

* refactoring

* clean

* update benchmark

* rename and delete some var

* minor fix
---
 configs/free_anchor/README.md                 |  23 +++
 .../retinanet_free_anchor_r101_fpn_1x.py      | 124 ++++++++++++
 .../retinanet_free_anchor_r50_fpn_1x.py       | 124 ++++++++++++
 ...retinanet_free_anchor_x101-32x4d_fpn_1x.py | 126 ++++++++++++
 mmdet/models/anchor_heads/__init__.py         |   3 +-
 .../anchor_heads/free_anchor_retina_head.py   | 188 ++++++++++++++++++
 6 files changed, 587 insertions(+), 1 deletion(-)
 create mode 100644 configs/free_anchor/README.md
 create mode 100644 configs/free_anchor/retinanet_free_anchor_r101_fpn_1x.py
 create mode 100644 configs/free_anchor/retinanet_free_anchor_r50_fpn_1x.py
 create mode 100644 configs/free_anchor/retinanet_free_anchor_x101-32x4d_fpn_1x.py
 create mode 100644 mmdet/models/anchor_heads/free_anchor_retina_head.py

diff --git a/configs/free_anchor/README.md b/configs/free_anchor/README.md
new file mode 100644
index 0000000..5e2b912
--- /dev/null
+++ b/configs/free_anchor/README.md
@@ -0,0 +1,23 @@
+# FreeAnchor: Learning to Match Anchors for Visual Object Detection
+
+## Introduction
+
+```
+@inproceedings{zhang2019freeanchor,
+  title   =  {{FreeAnchor}: Learning to Match Anchors for Visual Object Detection},
+  author  =  {Zhang, Xiaosong and Wan, Fang and Liu, Chang and Ji, Rongrong and Ye, Qixiang},
+  booktitle =  {Advances in Neural Information Processing Systems},
+  year    =  {2019}
+}
+```
+
+## Results and Models
+
+| Backbone    | Style   | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
+|:---------:|:-------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:--------:|
+| R-50        | pytorch | 1x      | 4.7 | 0.322 | 12.0 | 38.4 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/free_anchor/retinanet_free_anchor_r50_fpn_1x_20190914-84db6585.pth) |
+| R-101       | pytorch | 1x      | 6.6 | 0.437 | 9.7 | 40.3 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/free_anchor/retinanet_free_anchor_r101_fpn_1x_20190914-c4e4db81.pth) |
+| X-101-32x4d | pytorch | 1x      | 7.8 | 0.640 | 8.4 | 42.0 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/free_anchor/retinanet_free_anchor_x101-32x4d_fpn_1x_20190914-eb73b804.pth) |
+
+**Notes:**
+- All models are trained on 8 GPUs with 2 images per GPU (total batch size 16).
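+
+## Usage
+
+Training follows the standard mmdetection entry points, e.g.
+`tools/dist_train.sh` with one of the configs above. Below is a minimal
+inference sketch using mmdetection's high-level API; the checkpoint path and
+test image are placeholders to adapt to your local setup:
+
+```python
+from mmdet.apis import inference_detector, init_detector
+
+config_file = 'configs/free_anchor/retinanet_free_anchor_r50_fpn_1x.py'
+# assumed local path to a checkpoint downloaded from the table above
+checkpoint_file = 'checkpoints/retinanet_free_anchor_r50_fpn_1x_20190914-84db6585.pth'
+
+# build the detector from the config and load the trained weights
+model = init_detector(config_file, checkpoint_file, device='cuda:0')
+
+# returns a list of per-class arrays of [x1, y1, x2, y2, score]
+result = inference_detector(model, 'demo/demo.jpg')
+```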
diff --git a/configs/free_anchor/retinanet_free_anchor_r101_fpn_1x.py b/configs/free_anchor/retinanet_free_anchor_r101_fpn_1x.py
new file mode 100644
index 0000000..f6df698
--- /dev/null
+++ b/configs/free_anchor/retinanet_free_anchor_r101_fpn_1x.py
@@ -0,0 +1,124 @@
+# model settings
+model = dict(
+    type='RetinaNet',
+    pretrained='torchvision://resnet101',
+    backbone=dict(
+        type='ResNet',
+        depth=101,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs=True,
+        num_outs=5),
+    bbox_head=dict(
+        type='FreeAnchorRetinaHead',
+        num_classes=81,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        octave_base_scale=4,
+        scales_per_octave=3,
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[8, 16, 32, 64, 128],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=0.75)))
+# training and testing settings
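+# note: FreeAnchorRetinaHead performs its own anchor-object matching in its
+# loss, so the assigner below is kept only for interface compatibility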
+train_cfg = dict(
+    assigner=dict(
+        type='MaxIoUAssigner',
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.4,
+        min_pos_iou=0,
+        ignore_iof_thr=-1),
+    allowed_border=-1,
+    pos_weight=-1,
+    debug=False)
+test_cfg = dict(
+    nms_pre=1000,
+    min_bbox_size=0,
+    score_thr=0.05,
+    nms=dict(type='nms', iou_thr=0.5),
+    max_per_img=100)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=32),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(1333, 800),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline))
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+device_ids = range(8)
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/retinanet_free_anchor_r101_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/free_anchor/retinanet_free_anchor_r50_fpn_1x.py b/configs/free_anchor/retinanet_free_anchor_r50_fpn_1x.py
new file mode 100644
index 0000000..e14d82b
--- /dev/null
+++ b/configs/free_anchor/retinanet_free_anchor_r50_fpn_1x.py
@@ -0,0 +1,124 @@
+# model settings
+model = dict(
+    type='RetinaNet',
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs=True,
+        num_outs=5),
+    bbox_head=dict(
+        type='FreeAnchorRetinaHead',
+        num_classes=81,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        octave_base_scale=4,
+        scales_per_octave=3,
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[8, 16, 32, 64, 128],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=0.75)))
+# training and testing settings
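+# note: FreeAnchorRetinaHead performs its own anchor-object matching in its
+# loss, so the assigner below is kept only for interface compatibility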
+train_cfg = dict(
+    assigner=dict(
+        type='MaxIoUAssigner',
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.4,
+        min_pos_iou=0,
+        ignore_iof_thr=-1),
+    allowed_border=-1,
+    pos_weight=-1,
+    debug=False)
+test_cfg = dict(
+    nms_pre=1000,
+    min_bbox_size=0,
+    score_thr=0.05,
+    nms=dict(type='nms', iou_thr=0.5),
+    max_per_img=100)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=32),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(1333, 800),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline))
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+device_ids = range(8)
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/retinanet_free_anchor_r50_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/free_anchor/retinanet_free_anchor_x101-32x4d_fpn_1x.py b/configs/free_anchor/retinanet_free_anchor_x101-32x4d_fpn_1x.py
new file mode 100644
index 0000000..ee35b35
--- /dev/null
+++ b/configs/free_anchor/retinanet_free_anchor_x101-32x4d_fpn_1x.py
@@ -0,0 +1,126 @@
+# model settings
+model = dict(
+    type='RetinaNet',
+    pretrained='open-mmlab://resnext101_32x4d',
+    backbone=dict(
+        type='ResNeXt',
+        depth=101,
+        groups=32,
+        base_width=4,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs=True,
+        num_outs=5),
+    bbox_head=dict(
+        type='FreeAnchorRetinaHead',
+        num_classes=81,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        octave_base_scale=4,
+        scales_per_octave=3,
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[8, 16, 32, 64, 128],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=0.75)))
+# training and testing settings
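+# note: FreeAnchorRetinaHead performs its own anchor-object matching in its
+# loss, so the assigner below is kept only for interface compatibility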
+train_cfg = dict(
+    assigner=dict(
+        type='MaxIoUAssigner',
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.4,
+        min_pos_iou=0,
+        ignore_iof_thr=-1),
+    allowed_border=-1,
+    pos_weight=-1,
+    debug=False)
+test_cfg = dict(
+    nms_pre=1000,
+    min_bbox_size=0,
+    score_thr=0.05,
+    nms=dict(type='nms', iou_thr=0.5),
+    max_per_img=100)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=32),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(1333, 800),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline))
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+device_ids = range(8)
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/retinanet_free_anchor_x101-32x4d_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/mmdet/models/anchor_heads/__init__.py b/mmdet/models/anchor_heads/__init__.py
index 22c859f..54db861 100644
--- a/mmdet/models/anchor_heads/__init__.py
+++ b/mmdet/models/anchor_heads/__init__.py
@@ -1,6 +1,7 @@
 from .anchor_head import AnchorHead
 from .fcos_head import FCOSHead
 from .fovea_head import FoveaHead
+from .free_anchor_retina_head import FreeAnchorRetinaHead
 from .ga_retina_head import GARetinaHead
 from .ga_rpn_head import GARPNHead
 from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
@@ -12,5 +13,5 @@ from .ssd_head import SSDHead
 __all__ = [
     'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
     'GARPNHead', 'RetinaHead', 'GARetinaHead', 'SSDHead', 'FCOSHead',
-    'RepPointsHead', 'FoveaHead'
+    'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead'
 ]
diff --git a/mmdet/models/anchor_heads/free_anchor_retina_head.py b/mmdet/models/anchor_heads/free_anchor_retina_head.py
new file mode 100644
index 0000000..3179aad
--- /dev/null
+++ b/mmdet/models/anchor_heads/free_anchor_retina_head.py
@@ -0,0 +1,188 @@
+import torch
+import torch.nn.functional as F
+
+from mmdet.core import bbox2delta, bbox_overlaps, delta2bbox
+from ..registry import HEADS
+from .retina_head import RetinaHead
+
+
+@HEADS.register_module
+class FreeAnchorRetinaHead(RetinaHead):
+
+    def __init__(self,
+                 num_classes,
+                 in_channels,
+                 stacked_convs=4,
+                 octave_base_scale=4,
+                 scales_per_octave=3,
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 pre_anchor_topk=50,
+                 bbox_thr=0.6,
+                 gamma=2.0,
+                 alpha=0.5,
+                 **kwargs):
+        super(FreeAnchorRetinaHead,
+              self).__init__(num_classes, in_channels, stacked_convs,
+                             octave_base_scale, scales_per_octave, conv_cfg,
+                             norm_cfg, **kwargs)
+
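+        # FreeAnchor hyper-parameters:
+        # pre_anchor_topk: anchors collected into each object's candidate bag
+        # bbox_thr: IoU threshold t1 of the saturated linear matching function
+        # gamma, alpha: focal-loss-style focusing and balancing factors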
+        self.pre_anchor_topk = pre_anchor_topk
+        self.bbox_thr = bbox_thr
+        self.gamma = gamma
+        self.alpha = alpha
+
+    def loss(self,
+             cls_scores,
+             bbox_preds,
+             gt_bboxes,
+             gt_labels,
+             img_metas,
+             cfg,
+             gt_bboxes_ignore=None):
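+        """Compute FreeAnchor's positive and negative bag losses.
+
+        Anchor-object matching happens inside this method via top-k
+        candidate bags, so ``cfg`` is accepted only for interface
+        compatibility and is not used.
+        """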
+        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+        assert len(featmap_sizes) == len(self.anchor_generators)
+
+        anchor_list, _ = self.get_anchors(featmap_sizes, img_metas)
+        anchors = [torch.cat(anchor) for anchor in anchor_list]
+
+        # flatten each level's predictions and concatenate across levels
+        cls_scores = [
+            cls.permute(0, 2, 3,
+                        1).reshape(cls.size(0), -1, self.cls_out_channels)
+            for cls in cls_scores
+        ]
+        bbox_preds = [
+            bbox_pred.permute(0, 2, 3, 1).reshape(bbox_pred.size(0), -1, 4)
+            for bbox_pred in bbox_preds
+        ]
+        cls_scores = torch.cat(cls_scores, dim=1)
+        bbox_preds = torch.cat(bbox_preds, dim=1)
+
+        cls_prob = torch.sigmoid(cls_scores)
+        box_prob = []
+        num_pos = 0
+        positive_losses = []
+        for anchors_, gt_labels_, gt_bboxes_, cls_prob_, bbox_preds_ in zip(
+                anchors, gt_labels, gt_bboxes, cls_prob, bbox_preds):
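+            # labels are 1-based in mmdet v1 (0 is background); shift to
+            # 0-based indices into the sigmoid classification channels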
+            gt_labels_ -= 1
+
+            with torch.no_grad():
+                # box_localization: a_{j}^{loc}, shape: [j, 4]
+                pred_boxes = delta2bbox(anchors_, bbox_preds_,
+                                        self.target_means, self.target_stds)
+
+                # object_box_iou: IoU_{ij}^{loc}, shape: [i, j]
+                object_box_iou = bbox_overlaps(gt_bboxes_, pred_boxes)
+
+                # object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j]
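+                # (a saturated linear function of IoU: 0 at or below
+                # t1 = bbox_thr, 1 at the per-object max IoU t2, linear
+                # in between)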
+                t1 = self.bbox_thr
+                t2 = object_box_iou.max(
+                    dim=1, keepdim=True).values.clamp(min=t1 + 1e-12)
+                object_box_prob = ((object_box_iou - t1) / (t2 - t1)).clamp(
+                    min=0, max=1)
+
+                # object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j]
+                num_obj = gt_labels_.size(0)
+                indices = torch.stack(
+                    [torch.arange(num_obj).type_as(gt_labels_), gt_labels_],
+                    dim=0)
+                object_cls_box_prob = torch.sparse_coo_tensor(
+                    indices, object_box_prob)
+
+                # image_box_prob: P{a_{j} \in A_{+}}, shape: [j, c]
+                """
+                torch.sparse.max is not implemented, so the block from
+                "start" to "end" below emulates:
+                image_box_prob = torch.sparse.max(object_cls_box_prob,
+                                                  dim=0).t()
+                """
+                # start
+                box_cls_prob = torch.sparse.sum(
+                    object_cls_box_prob, dim=0).to_dense()
+
+                indices = torch.nonzero(box_cls_prob).t_()
+                if indices.numel() == 0:
+                    image_box_prob = torch.zeros(
+                        anchors_.size(0),
+                        self.cls_out_channels).type_as(object_box_prob)
+                else:
+                    nonzero_box_prob = torch.where(
+                        (gt_labels_.unsqueeze(dim=-1) == indices[0]),
+                        object_box_prob[:, indices[1]],
+                        torch.tensor(
+                            [0]).type_as(object_box_prob)).max(dim=0).values
+
+                    # unmap back to a dense tensor of shape [j, c]
+                    image_box_prob = torch.sparse_coo_tensor(
+                        indices.flip([0]),
+                        nonzero_box_prob,
+                        size=(anchors_.size(0),
+                              self.cls_out_channels)).to_dense()
+                # end
+
+                box_prob.append(image_box_prob)
+
+            # construct a bag for each object: the top pre_anchor_topk
+            # anchors ranked by IoU with the gt box
+            match_quality_matrix = bbox_overlaps(gt_bboxes_, anchors_)
+            _, matched = torch.topk(
+                match_quality_matrix,
+                self.pre_anchor_topk,
+                dim=1,
+                sorted=False)
+            del match_quality_matrix
+
+            # matched_cls_prob: P_{ij}^{cls}
+            matched_cls_prob = torch.gather(
+                cls_prob_[matched], 2,
+                gt_labels_.view(-1, 1, 1).repeat(1, self.pre_anchor_topk,
+                                                 1)).squeeze(2)
+
+            # matched_box_prob: P_{ij}^{loc}
+            matched_anchors = anchors_[matched]
+            matched_object_targets = bbox2delta(
+                matched_anchors,
+                gt_bboxes_.unsqueeze(dim=1).expand_as(matched_anchors),
+                self.target_means, self.target_stds)
+            loss_bbox = self.loss_bbox(
+                bbox_preds_[matched],
+                matched_object_targets,
+                reduction_override='none').sum(-1)
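+            # exp(-loss) maps the localization loss to a probability in
+            # (0, 1]; zero loss corresponds to probability 1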
+            matched_box_prob = torch.exp(-loss_bbox)
+
+            # positive_losses: {-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )}
+            num_pos += len(gt_bboxes_)
+            positive_losses.append(
+                self.positive_bag_loss(matched_cls_prob, matched_box_prob))
+        positive_loss = torch.cat(positive_losses).sum() / max(1, num_pos)
+
+        # box_prob: P{a_{j} \in A_{+}}
+        box_prob = torch.stack(box_prob, dim=0)
+
+        # negative_loss:
+        # \sum_{j}{ FL((1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg})) } / n||B||
+        negative_loss = self.negative_bag_loss(cls_prob, box_prob).sum() / max(
+            1, num_pos * self.pre_anchor_topk)
+
+        losses = {
+            'positive_bag_loss': positive_loss,
+            'negative_bag_loss': negative_loss
+        }
+        return losses
+
+    def positive_bag_loss(self, matched_cls_prob, matched_box_prob):
+        # bag_prob = Mean-max(matched_prob)
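+        # Mean-max(X) = sum(x / (1 - x)) / sum(1 / (1 - x)); acts like a
+        # mean when probabilities are small and approaches max(X) as they
+        # near 1, computed below with normalized weights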
+        matched_prob = matched_cls_prob * matched_box_prob
+        weight = 1 / torch.clamp(1 - matched_prob, 1e-12, None)
+        weight /= weight.sum(dim=1).unsqueeze(dim=-1)
+        bag_prob = (weight * matched_prob).sum(dim=1)
+        # positive_bag_loss = -self.alpha * log(bag_prob)
+        return self.alpha * F.binary_cross_entropy(
+            bag_prob, torch.ones_like(bag_prob), reduction='none')
+
+    def negative_bag_loss(self, cls_prob, box_prob):
+        prob = cls_prob * (1 - box_prob)
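+        # focal-loss-style modulation: prob**gamma * BCE(prob, 0) focuses
+        # the loss on hard negatives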
+        negative_bag_loss = prob**self.gamma * F.binary_cross_entropy(
+            prob, torch.zeros_like(prob), reduction='none')
+        return (1 - self.alpha) * negative_bag_loss
-- 
GitLab