diff --git a/configs/ghm/README.md b/configs/ghm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..659659541fbe74030c271f066d2e57d96527264d --- /dev/null +++ b/configs/ghm/README.md @@ -0,0 +1,16 @@ +# Gradient Harmonized Single-stage Detector + +## Introduction + +``` +@inproceedings{li2019gradient, + title={Gradient Harmonized Single-stage Detector}, + author={Li, Buyu and Liu, Yu and Wang, Xiaogang}, + booktitle={AAAI Conference on Artificial Intelligence}, + year={2019} +} +``` + +## Results and Models + +To be benchmarked. \ No newline at end of file diff --git a/configs/ghm/retinanet_ghm_r50_fpn_1x.py b/configs/ghm/retinanet_ghm_r50_fpn_1x.py new file mode 100644 index 0000000000000000000000000000000000000000..bb3a6548f70fc937c1a262156cea73f0d75a2622 --- /dev/null +++ b/configs/ghm/retinanet_ghm_r50_fpn_1x.py @@ -0,0 +1,129 @@ +# model settings +model = dict( + type='RetinaNet', + pretrained='modelzoo://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs=True, + num_outs=5), + bbox_head=dict( + type='RetinaHead', + num_classes=81, + in_channels=256, + stacked_convs=4, + feat_channels=256, + octave_base_scale=4, + scales_per_octave=3, + anchor_ratios=[0.5, 1.0, 2.0], + anchor_strides=[8, 16, 32, 64, 128], + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0], + loss_cls=dict( + type='GHMC', + bins=30, + momentum=0.75, + use_sigmoid=True, + loss_weight=1.0), + loss_bbox=dict( + type='GHMR', + mu=0.02, + bins=10, + momentum=0.7, + loss_weight=10.0))) +# training and testing settings +train_cfg = dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False) +test_cfg = dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_thr=0.5), + max_per_img=100) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +data = dict( + imgs_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0.5, + with_mask=False, + with_crowd=False, + with_label=True), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=False, + with_crowd=False, + with_label=True), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=False, + with_crowd=False, + with_label=False, + test_mode=True)) +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[8, 11]) +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +# runtime settings +total_epochs = 12 +device_ids = range(8) +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/ghm' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/mmdet/models/anchor_heads/anchor_head.py b/mmdet/models/anchor_heads/anchor_head.py index 6d68c204ccd26fc3a23a4a67da89c33af1e48226..1883e91a107175fd8dff6564ef1226db371382b4 100644 --- a/mmdet/models/anchor_heads/anchor_head.py +++ b/mmdet/models/anchor_heads/anchor_head.py @@ -57,7 +57,7 @@ class AnchorHead(nn.Module): self.target_stds = target_stds self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) - self.sampling = loss_cls['type'] not in ['FocalLoss'] + self.sampling = loss_cls['type'] not in ['FocalLoss', 'GHMC'] if self.use_sigmoid_cls: self.cls_out_channels = num_classes - 1 else: diff --git a/mmdet/models/losses/__init__.py b/mmdet/models/losses/__init__.py index a2d7b4aadfdb25d6b2376398d3b4ab5fa35cab68..817f4d26da2567a2d31d2c3e4404dfc085d11969 100644 --- a/mmdet/models/losses/__init__.py +++ b/mmdet/models/losses/__init__.py @@ -1,10 +1,11 @@ from .cross_entropy_loss import CrossEntropyLoss from .focal_loss import FocalLoss from .smooth_l1_loss import SmoothL1Loss +from .ghm_loss import GHMC, GHMR from .balanced_l1_loss import BalancedL1Loss from .iou_loss import IoULoss __all__ = [ 'CrossEntropyLoss', 'FocalLoss', 'SmoothL1Loss', 'BalancedL1Loss', - 'IoULoss' + 'IoULoss', 'GHMC', 'GHMR' ] diff --git a/mmdet/models/losses/ghm_loss.py b/mmdet/models/losses/ghm_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..681cd0ce9054c4ab2e995d13cfc0f2ab28018ca7 --- /dev/null +++ b/mmdet/models/losses/ghm_loss.py @@ -0,0 +1,172 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..registry import LOSSES + + +def _expand_binary_labels(labels, label_weights, label_channels): + bin_labels = labels.new_full((labels.size(0), label_channels), 0) + inds = torch.nonzero(labels >= 1).squeeze() + if inds.numel() > 0: + bin_labels[inds, labels[inds] - 1] = 1 + bin_label_weights = label_weights.view(-1, 1).expand( + label_weights.size(0), label_channels) + return bin_labels, bin_label_weights + + +@LOSSES.register_module +class GHMC(nn.Module): + """GHM Classification Loss. + + Details of the theorem can be viewed in the paper + "Gradient Harmonized Single-stage Detector". + https://arxiv.org/abs/1811.05181 + + Args: + bins (int): Number of the unit regions for distribution calculation. + momentum (float): The parameter for moving average. + use_sigmoid (bool): Can only be true for BCE based loss now. + loss_weight (float): The weight of the total GHM-C loss. + """ + def __init__( + self, + bins=10, + momentum=0, + use_sigmoid=True, + loss_weight=1.0): + super(GHMC, self).__init__() + self.bins = bins + self.momentum = momentum + self.edges = torch.arange(bins + 1).float().cuda() / bins + self.edges[-1] += 1e-6 + if momentum > 0: + self.acc_sum = torch.zeros(bins).cuda() + self.use_sigmoid = use_sigmoid + if not self.use_sigmoid: + raise NotImplementedError + self.loss_weight = loss_weight + + def forward(self, pred, target, label_weight, *args, **kwargs): + """Calculate the GHM-C loss. + + Args: + pred (float tensor of size [batch_num, class_num]): + The direct prediction of classification fc layer. + target (float tensor of size [batch_num, class_num]): + Binary class target for each sample. + label_weight (float tensor of size [batch_num, class_num]): + the value is 1 if the sample is valid and 0 if ignored. + Returns: + The gradient harmonized loss. + """ + # the target should be binary class label + if pred.dim() != target.dim(): + target, label_weight = _expand_binary_labels( + target, label_weight, pred.size(-1)) + target, label_weight = target.float(), label_weight.float() + edges = self.edges + mmt = self.momentum + weights = torch.zeros_like(pred) + + # gradient length + g = torch.abs(pred.sigmoid().detach() - target) + + valid = label_weight > 0 + tot = max(valid.float().sum().item(), 1.0) + n = 0 # n valid bins + for i in range(self.bins): + inds = (g >= edges[i]) & (g < edges[i+1]) & valid + num_in_bin = inds.sum().item() + if num_in_bin > 0: + if mmt > 0: + self.acc_sum[i] = mmt * self.acc_sum[i] \ + + (1 - mmt) * num_in_bin + weights[inds] = tot / self.acc_sum[i] + else: + weights[inds] = tot / num_in_bin + n += 1 + if n > 0: + weights = weights / n + + loss = F.binary_cross_entropy_with_logits( + pred, target, weights, reduction='sum') / tot + return loss * self.loss_weight + + +@LOSSES.register_module +class GHMR(nn.Module): + """GHM Regression Loss. + + Details of the theorem can be viewed in the paper + "Gradient Harmonized Single-stage Detector" + https://arxiv.org/abs/1811.05181 + + Args: + mu (float): The parameter for the Authentic Smooth L1 loss. + bins (int): Number of the unit regions for distribution calculation. + momentum (float): The parameter for moving average. + loss_weight (float): The weight of the total GHM-R loss. + """ + def __init__( + self, + mu=0.02, + bins=10, + momentum=0, + loss_weight=1.0): + super(GHMR, self).__init__() + self.mu = mu + self.bins = bins + self.edges = torch.arange(bins + 1).float().cuda() / bins + self.edges[-1] = 1e3 + self.momentum = momentum + if momentum > 0: + self.acc_sum = torch.zeros(bins).cuda() + self.loss_weight = loss_weight + + def forward(self, pred, target, label_weight, avg_factor=None): + """Calculate the GHM-R loss. + + Args: + pred (float tensor of size [batch_num, 4 (* class_num)]): + The prediction of box regression layer. Channel number can be 4 + or 4 * class_num depending on whether it is class-agnostic. + target (float tensor of size [batch_num, 4 (* class_num)]): + The target regression values with the same size of pred. + label_weight (float tensor of size [batch_num, 4 (* class_num)]): + The weight of each sample, 0 if ignored. + Returns: + The gradient harmonized loss. + """ + mu = self.mu + edges = self.edges + mmt = self.momentum + + # ASL1 loss + diff = pred - target + loss = torch.sqrt(diff * diff + mu * mu) - mu + + # gradient length + g = torch.abs(diff / torch.sqrt(mu * mu + diff * diff)).detach() + weights = torch.zeros_like(g) + + valid = label_weight > 0 + tot = max(label_weight.float().sum().item(), 1.0) + n = 0 # n: valid bins + for i in range(self.bins): + inds = (g >= edges[i]) & (g < edges[i+1]) & valid + num_in_bin = inds.sum().item() + if num_in_bin > 0: + n += 1 + if mmt > 0: + self.acc_sum[i] = mmt * self.acc_sum[i] \ + + (1 - mmt) * num_in_bin + weights[inds] = tot / self.acc_sum[i] + else: + weights[inds] = tot / num_in_bin + if n > 0: + weights /= n + + loss = loss * weights + loss = loss.sum() / tot + return loss * self.loss_weight