From 32afa752710bd4154a0de3f246d4ea446f34cf88 Mon Sep 17 00:00:00 2001
From: libuyu <libuyu_braiten@outlook.com>
Date: Tue, 28 May 2019 19:44:27 +0800
Subject: [PATCH] Code for "Gradient Harmonized Single-stage Detector" (#706)

* finish ghm in newest version with AP 37.0

* add ghm config file

* reformat for PEP8

* reformat for flake8

* add documentation for GHM and tensorize the params

* improve the docs

* add readme and update configs
---
 configs/ghm/README.md                    |  16 +++
 configs/ghm/retinanet_ghm_r50_fpn_1x.py  | 129 +++++++++++++++++
 mmdet/models/anchor_heads/anchor_head.py |   2 +-
 mmdet/models/losses/__init__.py          |   3 +-
 mmdet/models/losses/ghm_loss.py          | 172 +++++++++++++++++++++++
 5 files changed, 320 insertions(+), 2 deletions(-)
 create mode 100644 configs/ghm/README.md
 create mode 100644 configs/ghm/retinanet_ghm_r50_fpn_1x.py
 create mode 100644 mmdet/models/losses/ghm_loss.py

diff --git a/configs/ghm/README.md b/configs/ghm/README.md
new file mode 100644
index 0000000..6596595
--- /dev/null
+++ b/configs/ghm/README.md
@@ -0,0 +1,16 @@
+# Gradient Harmonized Single-stage Detector
+
+## Introduction
+
+```
+@inproceedings{li2019gradient,
+  title={Gradient Harmonized Single-stage Detector},
+  author={Li, Buyu and Liu, Yu and Wang, Xiaogang},
+  booktitle={AAAI Conference on Artificial Intelligence},
+  year={2019}
+}
+```
+
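+## Usage
+
+A minimal sketch of how to switch a RetinaNet head to the GHM losses; the
+values below mirror `retinanet_ghm_r50_fpn_1x.py` in this folder:
+
+```python
+loss_cls=dict(
+    type='GHMC',
+    bins=30,
+    momentum=0.75,
+    use_sigmoid=True,
+    loss_weight=1.0),
+loss_bbox=dict(
+    type='GHMR',
+    mu=0.02,
+    bins=10,
+    momentum=0.7,
+    loss_weight=10.0)
+```
+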
+## Results and Models
+
+To be benchmarked.
diff --git a/configs/ghm/retinanet_ghm_r50_fpn_1x.py b/configs/ghm/retinanet_ghm_r50_fpn_1x.py
new file mode 100644
index 0000000..bb3a654
--- /dev/null
+++ b/configs/ghm/retinanet_ghm_r50_fpn_1x.py
@@ -0,0 +1,129 @@
+# model settings
+model = dict(
+    type='RetinaNet',
+    pretrained='modelzoo://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs=True,
+        num_outs=5),
+    bbox_head=dict(
+        type='RetinaHead',
+        num_classes=81,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        octave_base_scale=4,
+        scales_per_octave=3,
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[8, 16, 32, 64, 128],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
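+        # GHM-C classification loss, used here in place of RetinaNet's
+        # default FocalLoss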
+        loss_cls=dict(
+            type='GHMC',
+            bins=30,
+            momentum=0.75,
+            use_sigmoid=True,
+            loss_weight=1.0),
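+        # GHM-R regression loss, used in place of the default smooth L1 loss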
+        loss_bbox=dict(
+            type='GHMR',
+            mu=0.02,
+            bins=10,
+            momentum=0.7,
+            loss_weight=10.0)))
+# training and testing settings
+train_cfg = dict(
+    assigner=dict(
+        type='MaxIoUAssigner',
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.4,
+        min_pos_iou=0,
+        ignore_iof_thr=-1),
+    allowed_border=-1,
+    pos_weight=-1,
+    debug=False)
+test_cfg = dict(
+    nms_pre=1000,
+    min_bbox_size=0,
+    score_thr=0.05,
+    nms=dict(type='nms', iou_thr=0.5),
+    max_per_img=100)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=False,
+        with_crowd=False,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=False,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+device_ids = range(8)
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/ghm'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/mmdet/models/anchor_heads/anchor_head.py b/mmdet/models/anchor_heads/anchor_head.py
index 6d68c20..1883e91 100644
--- a/mmdet/models/anchor_heads/anchor_head.py
+++ b/mmdet/models/anchor_heads/anchor_head.py
@@ -57,7 +57,7 @@ class AnchorHead(nn.Module):
         self.target_stds = target_stds
 
         self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
-        self.sampling = loss_cls['type'] not in ['FocalLoss']
+        self.sampling = loss_cls['type'] not in ['FocalLoss', 'GHMC']
         if self.use_sigmoid_cls:
             self.cls_out_channels = num_classes - 1
         else:
diff --git a/mmdet/models/losses/__init__.py b/mmdet/models/losses/__init__.py
index a2d7b4a..817f4d2 100644
--- a/mmdet/models/losses/__init__.py
+++ b/mmdet/models/losses/__init__.py
@@ -1,10 +1,11 @@
 from .cross_entropy_loss import CrossEntropyLoss
 from .focal_loss import FocalLoss
 from .smooth_l1_loss import SmoothL1Loss
+from .ghm_loss import GHMC, GHMR
 from .balanced_l1_loss import BalancedL1Loss
 from .iou_loss import IoULoss
 
 __all__ = [
     'CrossEntropyLoss', 'FocalLoss', 'SmoothL1Loss', 'BalancedL1Loss',
-    'IoULoss'
+    'IoULoss', 'GHMC', 'GHMR'
 ]
diff --git a/mmdet/models/losses/ghm_loss.py b/mmdet/models/losses/ghm_loss.py
new file mode 100644
index 0000000..681cd0c
--- /dev/null
+++ b/mmdet/models/losses/ghm_loss.py
@@ -0,0 +1,172 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..registry import LOSSES
+
+
+def _expand_binary_labels(labels, label_weights, label_channels):
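+    # Expand class-index labels (0 = background, k >= 1 = class k) into
+    # one-hot binary labels of shape (N, label_channels), broadcasting the
+    # per-sample weights across channels.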
+    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
+    inds = torch.nonzero(labels >= 1).squeeze()
+    if inds.numel() > 0:
+        bin_labels[inds, labels[inds] - 1] = 1
+    bin_label_weights = label_weights.view(-1, 1).expand(
+        label_weights.size(0), label_channels)
+    return bin_labels, bin_label_weights
+
+
+@LOSSES.register_module
+class GHMC(nn.Module):
+    """GHM Classification Loss.
+
+    Details of the method can be found in the paper
+    "Gradient Harmonized Single-stage Detector":
+    https://arxiv.org/abs/1811.05181
+
+    Args:
+        bins (int): Number of the unit regions for distribution calculation.
+        momentum (float): The parameter for moving average.
+        use_sigmoid (bool): Whether to use sigmoid activation; only True
+            (BCE-based loss) is supported for now.
+        loss_weight (float): The weight of the total GHM-C loss.
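+
+    Example (a minimal sketch; a CUDA device is assumed, since the bin
+    edges are created on GPU):
+        >>> loss_cls = GHMC(bins=30, momentum=0.75)
+        >>> pred = torch.randn(8, 80).cuda()
+        >>> target = torch.randint(0, 2, (8, 80)).float().cuda()
+        >>> label_weight = torch.ones(8, 80).cuda()
+        >>> loss = loss_cls(pred, target, label_weight)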
+    """
+    def __init__(
+            self,
+            bins=10,
+            momentum=0,
+            use_sigmoid=True,
+            loss_weight=1.0):
+        super(GHMC, self).__init__()
+        self.bins = bins
+        self.momentum = momentum
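+        # bin edges are evenly spaced over [0, 1]; the last edge is nudged
+        # up so that g == 1 still falls into the last bin (NOTE: the edges
+        # are created directly on GPU, so a CUDA device is assumed)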
+        self.edges = torch.arange(bins + 1).float().cuda() / bins
+        self.edges[-1] += 1e-6
+        if momentum > 0:
+            self.acc_sum = torch.zeros(bins).cuda()
+        self.use_sigmoid = use_sigmoid
+        if not self.use_sigmoid:
+            raise NotImplementedError
+        self.loss_weight = loss_weight
+
+    def forward(self, pred, target, label_weight, *args, **kwargs):
+        """Calculate the GHM-C loss.
+
+        Args:
+            pred (float tensor of size [batch_num, class_num]):
+                The direct prediction of classification fc layer.
+            target (float tensor of size [batch_num, class_num]):
+                Binary class target for each sample.
+            label_weight (float tensor of size [batch_num, class_num]):
+                The value is 1 if the sample is valid and 0 if ignored.
+        Returns:
+            The gradient harmonized loss.
+        """
+        # the target should be a binary class label
+        if pred.dim() != target.dim():
+            target, label_weight = _expand_binary_labels(
+                target, label_weight, pred.size(-1))
+        target, label_weight = target.float(), label_weight.float()
+        edges = self.edges
+        mmt = self.momentum
+        weights = torch.zeros_like(pred)
+
+        # gradient length
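+        # for sigmoid + BCE, |d(loss)/d(logit)| = |sigmoid(pred) - target|,
+        # so g is exactly the per-sample gradient norm from the paper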
+        g = torch.abs(pred.sigmoid().detach() - target)
+
+        valid = label_weight > 0
+        tot = max(valid.float().sum().item(), 1.0)
+        n = 0  # number of valid (non-empty) bins
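+        # weight each sample by the inverse of its bin's (optionally
+        # EMA-smoothed) population, i.e. by the reciprocal gradient density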
+        for i in range(self.bins):
+            inds = (g >= edges[i]) & (g < edges[i+1]) & valid
+            num_in_bin = inds.sum().item()
+            if num_in_bin > 0:
+                if mmt > 0:
+                    self.acc_sum[i] = mmt * self.acc_sum[i] \
+                        + (1 - mmt) * num_in_bin
+                    weights[inds] = tot / self.acc_sum[i]
+                else:
+                    weights[inds] = tot / num_in_bin
+                n += 1
+        if n > 0:
+            weights = weights / n
+
+        loss = F.binary_cross_entropy_with_logits(
+            pred, target, weights, reduction='sum') / tot
+        return loss * self.loss_weight
+
+
+@LOSSES.register_module
+class GHMR(nn.Module):
+    """GHM Regression Loss.
+
+    Details of the method can be found in the paper
+    "Gradient Harmonized Single-stage Detector":
+    https://arxiv.org/abs/1811.05181
+
+    Args:
+        mu (float): The parameter of the Authentic Smooth L1 (ASL1) loss.
+        bins (int): Number of the unit regions for distribution calculation.
+        momentum (float): The parameter for moving average.
+        loss_weight (float): The weight of the total GHM-R loss.
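+
+    Example (a minimal sketch; a CUDA device is assumed, as in GHMC):
+        >>> loss_bbox = GHMR(mu=0.02, bins=10, momentum=0.7)
+        >>> pred = torch.randn(8, 4).cuda()
+        >>> target = torch.randn(8, 4).cuda()
+        >>> label_weight = torch.ones(8, 4).cuda()
+        >>> loss = loss_bbox(pred, target, label_weight)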
+    """
+    def __init__(
+            self,
+            mu=0.02,
+            bins=10,
+            momentum=0,
+            loss_weight=1.0):
+        super(GHMR, self).__init__()
+        self.mu = mu
+        self.bins = bins
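+        # g is strictly below 1 for the ASL1 loss, so raising the last edge
+        # to 1e3 guarantees every valid sample lands in a bin (as in GHMC,
+        # the edges live on GPU, so CUDA is assumed)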
+        self.edges = torch.arange(bins + 1).float().cuda() / bins
+        self.edges[-1] = 1e3
+        self.momentum = momentum
+        if momentum > 0:
+            self.acc_sum = torch.zeros(bins).cuda()
+        self.loss_weight = loss_weight
+
+    def forward(self, pred, target, label_weight, avg_factor=None):
+        """Calculate the GHM-R loss.
+
+        Args:
+            pred (float tensor of size [batch_num, 4 (* class_num)]):
+                The prediction of box regression layer. Channel number can be 4
+                or 4 * class_num depending on whether it is class-agnostic.
+            target (float tensor of size [batch_num, 4 (* class_num)]):
+                The target regression values with the same size as pred.
+            label_weight (float tensor of size [batch_num, 4 (* class_num)]):
+                The weight of each sample, 0 if ignored.
+        Returns:
+            The gradient harmonized loss.
+        """
+        mu = self.mu
+        edges = self.edges
+        mmt = self.momentum
+
+        # ASL1 loss
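+        # sqrt(diff^2 + mu^2) - mu behaves like L2 near zero and like
+        # L1 (|diff| - mu) for large |diff|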
+        diff = pred - target
+        loss = torch.sqrt(diff * diff + mu * mu) - mu
+
+        # gradient length
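+        # d(ASL1)/d(diff) = diff / sqrt(diff^2 + mu^2), bounded in (-1, 1)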
+        g = torch.abs(diff / torch.sqrt(mu * mu + diff * diff)).detach()
+        weights = torch.zeros_like(g)
+
+        valid = label_weight > 0
+        tot = max(label_weight.float().sum().item(), 1.0)
+        n = 0  # number of valid (non-empty) bins
+        for i in range(self.bins):
+            inds = (g >= edges[i]) & (g < edges[i+1]) & valid
+            num_in_bin = inds.sum().item()
+            if num_in_bin > 0:
+                n += 1
+                if mmt > 0:
+                    self.acc_sum[i] = mmt * self.acc_sum[i] \
+                        + (1 - mmt) * num_in_bin
+                    weights[inds] = tot / self.acc_sum[i]
+                else:
+                    weights[inds] = tot / num_in_bin
+        if n > 0:
+            weights /= n
+
+        loss = loss * weights
+        loss = loss.sum() / tot
+        return loss * self.loss_weight
-- 
GitLab