diff --git a/README.md b/README.md
index 2b666b2de461aba476be62d257b55e7c17291d63..0a77858939a8e1de0bac9d2656fa215cf38c629e 100644
--- a/README.md
+++ b/README.md
@@ -70,6 +70,7 @@ Results and models are available in the [Model zoo](docs/MODEL_ZOO.md).
 | Foveabox           | ✓ | ✓ | ☠ | ✗ | ✓ |
 | FreeAnchor         | ✓ | ✓ | ☠ | ✗ | ✓ |
 | NAS-FPN            | ✓ | ✓ | ☠ | ✗ | ✓ |
+| ATSS               | ✓ | ✓ | ☠ | ✗ | ✓ |
 
 Other features
 - [x] DCNv2
diff --git a/configs/atss/README.md b/configs/atss/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2457358ccce01bc3457f6d958b3da04456678211
--- /dev/null
+++ b/configs/atss/README.md
@@ -0,0 +1,20 @@
+# Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
+
+
+## Introduction
+
+```
+@article{zhang2019bridging,
+  title = {Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection},
+  author = {Zhang, Shifeng and Chi, Cheng and Yao, Yongqiang and Lei, Zhen and Li, Stan Z.},
+  journal = {arXiv preprint arXiv:1912.02424},
+  year = {2019}
+}
+```
+
+
+## Results and Models
+
+| Backbone | Style | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
+|:---------:|:-------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:--------:|
+| R-50 | pytorch | 1x | 3.6 | 0.357 | 12.8 | 39.2 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/atss/atss_r50_fpn_1x_20200113-a7aa251e.pth)|
diff --git a/configs/atss/atss_r50_fpn_1x.py b/configs/atss/atss_r50_fpn_1x.py
new file mode 100644
index 0000000000000000000000000000000000000000..e19a39da967e588072692505048bf7cc9d9ff128
--- /dev/null
+++ b/configs/atss/atss_r50_fpn_1x.py
@@ -0,0 +1,127 @@
+# model settings
+model = dict(
+    type='ATSS',
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs=True,
+        extra_convs_on_inputs=False,
+        num_outs=5),
+    bbox_head=dict(
+        type='ATSSHead',
+        num_classes=81,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        octave_base_scale=8,
+        scales_per_octave=1,
+        anchor_ratios=[1.0],
+        anchor_strides=[8, 16, 32, 64, 128],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
+        loss_centerness=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))
+# training and testing settings
+train_cfg = dict(
+    assigner=dict(type='ATSSAssigner', topk=9),
+    allowed_border=-1,
+    pos_weight=-1,
+    debug=False)
+test_cfg = dict(
+    nms_pre=1000,
+    min_bbox_size=0,
+    score_thr=0.05,
+    nms=dict(type='nms', iou_thr=0.6),
+    max_per_img=100)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=32),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
+]
+test_pipeline = [
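+    # single-scale testing: resize to (1333, 800), no flip augmentation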
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(1333, 800),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline))
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/atss_r50_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/docs/MODEL_ZOO.md b/docs/MODEL_ZOO.md
index 6b93186b988de8e3c5424d9d8c0ad27217a1d941..954a4772a2114b8983d8a3fa656104a1f64b8bdb 100644
--- a/docs/MODEL_ZOO.md
+++ b/docs/MODEL_ZOO.md
@@ -284,6 +284,9 @@ Please refer to [Rethinking ImageNet Pre-training](https://github.com/open-mmlab
 ### NAS-FPN
 Please refer to [NAS-FPN](https://github.com/open-mmlab/mmdetection/blob/master/configs/nas_fpn) for details.
 
+### ATSS
+Please refer to [ATSS](https://github.com/open-mmlab/mmdetection/blob/master/configs/atss) for details.
+
 ### Other datasets
 We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face).
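The `train_cfg` above plugs in `ATSSAssigner(topk=9)`, implemented in `mmdet/core/bbox/assigners/atss_assigner.py` below. As a toy illustration of its selection rule — per FPN level, take the `topk` anchors whose centers are closest to the gt center, then keep those whose IoU clears mean + std of the candidates' IoUs — here is a self-contained sketch with made-up numbers (not part of the patch):

```python
import torch

# hypothetical IoUs and center distances of 9 anchors w.r.t. a single gt
ious = torch.tensor([0.55, 0.20, 0.48, 0.05,   # level 0 (4 anchors)
                     0.62, 0.30, 0.10,         # level 1 (3 anchors)
                     0.15, 0.08])              # level 2 (2 anchors)
center_dists = torch.tensor([3., 9., 4., 15., 2., 7., 11., 6., 13.])
num_level_bboxes = [4, 3, 2]
topk = 2  # the config above uses topk=9

candidates, start = [], 0
for n in num_level_bboxes:
    # per level, pick the topk anchors whose centers are closest to the gt
    _, idx = center_dists[start:start + n].topk(min(topk, n), largest=False)
    candidates.append(idx + start)
    start += n
candidates = torch.cat(candidates)

cand_ious = ious[candidates]
thr = cand_ious.mean() + cand_ious.std()  # adaptive per-gt IoU threshold
print(candidates[cand_ious >= thr])       # tensor([4]) for these numbers
```

The real assigner additionally requires each positive's center to lie inside the gt box and resolves anchors claimed by several gts in favor of the highest IoU.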
diff --git a/mmdet/core/anchor/__init__.py b/mmdet/core/anchor/__init__.py
index dfeb3b407300beee4d530c111b4187693b190e2b..06e2d12323bbf9dc95ba84a6aaddfa1ac6d5f5bd 100644
--- a/mmdet/core/anchor/__init__.py
+++ b/mmdet/core/anchor/__init__.py
@@ -1,10 +1,12 @@
 from .anchor_generator import AnchorGenerator
-from .anchor_target import anchor_inside_flags, anchor_target
+from .anchor_target import (anchor_inside_flags, anchor_target,
+                            images_to_levels, unmap)
 from .guided_anchor_target import ga_loc_target, ga_shape_target
 from .point_generator import PointGenerator
 from .point_target import point_target
 
 __all__ = [
     'AnchorGenerator', 'anchor_target', 'anchor_inside_flags', 'ga_loc_target',
-    'ga_shape_target', 'PointGenerator', 'point_target'
+    'ga_shape_target', 'PointGenerator', 'point_target', 'images_to_levels',
+    'unmap'
 ]
diff --git a/mmdet/core/bbox/assigners/__init__.py b/mmdet/core/bbox/assigners/__init__.py
index 93eebb775be7720f232f122050d5f753117f7731..4ed1d5643184e4c9b700bd9dd1ca39072f1e740a 100644
--- a/mmdet/core/bbox/assigners/__init__.py
+++ b/mmdet/core/bbox/assigners/__init__.py
@@ -1,10 +1,11 @@
 from .approx_max_iou_assigner import ApproxMaxIoUAssigner
 from .assign_result import AssignResult
+from .atss_assigner import ATSSAssigner
 from .base_assigner import BaseAssigner
 from .max_iou_assigner import MaxIoUAssigner
 from .point_assigner import PointAssigner
 
 __all__ = [
     'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
-    'PointAssigner'
+    'PointAssigner', 'ATSSAssigner'
 ]
diff --git a/mmdet/core/bbox/assigners/atss_assigner.py b/mmdet/core/bbox/assigners/atss_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..e442ac709d854a43883387ae88246776b5744743
--- /dev/null
+++ b/mmdet/core/bbox/assigners/atss_assigner.py
@@ -0,0 +1,159 @@
+import torch
+
+from ..geometry import bbox_overlaps
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+class ATSSAssigner(BaseAssigner):
+    """Assign a corresponding gt bbox or background to each bbox.
+
+    Each proposal will be assigned `0` or a positive integer
+    indicating the ground truth index.
+
+    - 0: negative sample, no assigned gt
+    - positive integer: positive sample, index (1-based) of assigned gt
+
+    Args:
+        topk (int): number of anchors selected per pyramid level
+    """
+
+    def __init__(self, topk):
+        self.topk = topk
+
+    # https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py
+
+    def assign(self,
+               bboxes,
+               num_level_bboxes,
+               gt_bboxes,
+               gt_bboxes_ignore=None,
+               gt_labels=None):
+        """Assign gt to bboxes.
+
+        The assignment is done in the following steps:
+
+        1. compute the IoU between all bboxes (over all pyramid levels)
+           and gts
+        2. compute the center distance between all bboxes and gts
+        3. on each pyramid level, for each gt, select the k bboxes whose
+           centers are closest to the gt center, so we select k*l bboxes
+           in total as candidates for each gt (l is the number of levels)
+        4. take the IoUs of these candidates and compute their mean and
+           std; set mean + std as the IoU threshold
+        5. select candidates whose IoU is greater than or equal to the
+           threshold as positive
+        6. keep only positive samples whose centers lie inside the gt
+
+
+        Args:
+            bboxes (Tensor): Bounding boxes to be assigned, shape (n, 4).
+            num_level_bboxes (List): num of bboxes in each level
+            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+                labelled as `ignored`, e.g., crowd boxes in COCO.
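+                (accepted for interface compatibility; the current
+                implementation does not use it)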
+            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+        Returns:
+            :obj:`AssignResult`: The assign result.
+        """
+        INF = 100000000
+        bboxes = bboxes[:, :4]
+        num_gt, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
+
+        # compute iou between all bbox and gt
+        overlaps = bbox_overlaps(bboxes, gt_bboxes)
+
+        # assign 0 by default
+        assigned_gt_inds = overlaps.new_full((num_bboxes, ),
+                                             0,
+                                             dtype=torch.long)
+
+        if num_gt == 0 or num_bboxes == 0:
+            # No ground truth or boxes, return empty assignment
+            max_overlaps = overlaps.new_zeros((num_bboxes, ))
+            if num_gt == 0:
+                # No truth, assign everything to background
+                assigned_gt_inds[:] = 0
+            if gt_labels is None:
+                assigned_labels = None
+            else:
+                assigned_labels = overlaps.new_zeros((num_bboxes, ),
+                                                     dtype=torch.long)
+            return AssignResult(
+                num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
+
+        # compute center distance between all bbox and gt
+        gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
+        gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
+        gt_points = torch.stack((gt_cx, gt_cy), dim=1)
+
+        bboxes_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
+        bboxes_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
+        bboxes_points = torch.stack((bboxes_cx, bboxes_cy), dim=1)
+
+        distances = (bboxes_points[:, None, :] -
+                     gt_points[None, :, :]).pow(2).sum(-1).sqrt()
+
+        # select candidates based on the center distance
+        candidate_idxs = []
+        start_idx = 0
+        for level, bboxes_per_level in enumerate(num_level_bboxes):
+            # on each pyramid level, for each gt,
+            # select k bboxes whose centers are closest to the gt center
+            end_idx = start_idx + bboxes_per_level
+            distances_per_level = distances[start_idx:end_idx, :]
+            _, topk_idxs_per_level = distances_per_level.topk(
+                min(self.topk, bboxes_per_level), dim=0, largest=False)
+            # (min() guards against levels with fewer than topk anchors)
+            candidate_idxs.append(topk_idxs_per_level + start_idx)
+            start_idx = end_idx
+        candidate_idxs = torch.cat(candidate_idxs, dim=0)
+
+        # take the IoUs of these candidates, compute their mean and std,
+        # and set mean + std as the IoU threshold
+        candidate_overlaps = overlaps[candidate_idxs, torch.arange(num_gt)]
+        overlaps_mean_per_gt = candidate_overlaps.mean(0)
+        overlaps_std_per_gt = candidate_overlaps.std(0)
+        overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
+
+        is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]
+
+        # keep only positive samples whose centers lie inside the gt
+        for gt_idx in range(num_gt):
+            candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
+        ep_bboxes_cx = bboxes_cx.view(1, -1).expand(
+            num_gt, num_bboxes).contiguous().view(-1)
+        ep_bboxes_cy = bboxes_cy.view(1, -1).expand(
+            num_gt, num_bboxes).contiguous().view(-1)
+        candidate_idxs = candidate_idxs.view(-1)
+
+        # calculate the left, top, right, bottom distance between positive
+        # bbox center and gt side
+        l_ = ep_bboxes_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0]
+        t_ = ep_bboxes_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1]
+        r_ = gt_bboxes[:, 2] - ep_bboxes_cx[candidate_idxs].view(-1, num_gt)
+        b_ = gt_bboxes[:, 3] - ep_bboxes_cy[candidate_idxs].view(-1, num_gt)
+        is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
+        is_pos = is_pos & is_in_gts
+
+        # if an anchor box is assigned to multiple gts,
+        # the one with the highest IoU will be selected.
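+        # (candidate_idxs was shifted by gt_idx * num_bboxes above, so it
+        # now indexes a flattened (num_gt, num_bboxes) view of the IoU
+        # matrix; filling non-candidates with -INF below lets the
+        # per-anchor max over gts implement this rule)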
+        overlaps_inf = torch.full_like(overlaps,
+                                       -INF).t().contiguous().view(-1)
+        index = candidate_idxs.view(-1)[is_pos.view(-1)]
+        overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index]
+        overlaps_inf = overlaps_inf.view(num_gt, -1).t()
+
+        max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
+        assigned_gt_inds[
+            max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1
+
+        if gt_labels is not None:
+            assigned_labels = assigned_gt_inds.new_zeros((num_bboxes, ))
+            pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze()
+            if pos_inds.numel() > 0:
+                assigned_labels[pos_inds] = gt_labels[
+                    assigned_gt_inds[pos_inds] - 1]
+        else:
+            assigned_labels = None
+        return AssignResult(
+            num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
diff --git a/mmdet/models/anchor_heads/__init__.py b/mmdet/models/anchor_heads/__init__.py
index c693c0fd209cf34eaf01a6ce5b28e40d883b6ea6..3364c4c9113a7f9f816bd6a56f515d186a9821ce 100644
--- a/mmdet/models/anchor_heads/__init__.py
+++ b/mmdet/models/anchor_heads/__init__.py
@@ -1,4 +1,5 @@
 from .anchor_head import AnchorHead
+from .atss_head import ATSSHead
 from .fcos_head import FCOSHead
 from .fovea_head import FoveaHead
 from .free_anchor_retina_head import FreeAnchorRetinaHead
@@ -14,5 +15,6 @@ from .ssd_head import SSDHead
 __all__ = [
     'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
     'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead',
-    'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead'
+    'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead',
+    'ATSSHead'
 ]
diff --git a/mmdet/models/anchor_heads/atss_head.py b/mmdet/models/anchor_heads/atss_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0f2e0abc2703beef4a33d12d53b78393cb7bd5c
--- /dev/null
+++ b/mmdet/models/anchor_heads/atss_head.py
@@ -0,0 +1,487 @@
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from mmcv.cnn import normal_init
+
+from mmdet.core import (PseudoSampler, anchor_inside_flags, bbox2delta,
+                        build_assigner, delta2bbox, force_fp32,
+                        images_to_levels, multi_apply, multiclass_nms, unmap)
+from ..builder import build_loss
+from ..registry import HEADS
+from ..utils import ConvModule, Scale, bias_init_with_prob
+from .anchor_head import AnchorHead
+
+
+def reduce_mean(tensor):
+    if not (dist.is_available() and dist.is_initialized()):
+        return tensor
+    tensor = tensor.clone()
+    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.reduce_op.SUM)
+    return tensor
+
+
+@HEADS.register_module
+class ATSSHead(AnchorHead):
+    """
+    Bridging the Gap Between Anchor-based and Anchor-free Detection via
+    Adaptive Training Sample Selection
+
+    The ATSS head structure is similar to that of FCOS; however, ATSS uses
+    anchor boxes and assigns labels with Adaptive Training Sample Selection
+    instead of max-IoU matching.
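+    Unlike FCOS, it regresses deltas w.r.t. anchor boxes rather than
+    point-to-boundary distances, and predicts centerness from the
+    regression tower.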
+
+    https://arxiv.org/abs/1912.02424
+    """
+
+    def __init__(self,
+                 num_classes,
+                 in_channels,
+                 stacked_convs=4,
+                 octave_base_scale=4,
+                 scales_per_octave=1,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+                 loss_centerness=dict(
+                     type='CrossEntropyLoss',
+                     use_sigmoid=True,
+                     loss_weight=1.0),
+                 **kwargs):
+        self.stacked_convs = stacked_convs
+        self.octave_base_scale = octave_base_scale
+        self.scales_per_octave = scales_per_octave
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+
+        octave_scales = np.array(
+            [2**(i / scales_per_octave) for i in range(scales_per_octave)])
+        anchor_scales = octave_scales * octave_base_scale
+        super(ATSSHead, self).__init__(
+            num_classes, in_channels, anchor_scales=anchor_scales, **kwargs)
+
+        self.loss_centerness = build_loss(loss_centerness)
+
+    def _init_layers(self):
+        self.relu = nn.ReLU(inplace=True)
+        self.cls_convs = nn.ModuleList()
+        self.reg_convs = nn.ModuleList()
+        for i in range(self.stacked_convs):
+            chn = self.in_channels if i == 0 else self.feat_channels
+            self.cls_convs.append(
+                ConvModule(
+                    chn,
+                    self.feat_channels,
+                    3,
+                    stride=1,
+                    padding=1,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg))
+            self.reg_convs.append(
+                ConvModule(
+                    chn,
+                    self.feat_channels,
+                    3,
+                    stride=1,
+                    padding=1,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg))
+        self.atss_cls = nn.Conv2d(
+            self.feat_channels,
+            self.num_anchors * self.cls_out_channels,
+            3,
+            padding=1)
+        self.atss_reg = nn.Conv2d(
+            self.feat_channels, self.num_anchors * 4, 3, padding=1)
+        self.atss_centerness = nn.Conv2d(
+            self.feat_channels, self.num_anchors * 1, 3, padding=1)
+        self.scales = nn.ModuleList([Scale(1.0) for _ in self.anchor_strides])
+
+    def init_weights(self):
+        for m in self.cls_convs:
+            normal_init(m.conv, std=0.01)
+        for m in self.reg_convs:
+            normal_init(m.conv, std=0.01)
+        bias_cls = bias_init_with_prob(0.01)
+        normal_init(self.atss_cls, std=0.01, bias=bias_cls)
+        normal_init(self.atss_reg, std=0.01)
+        normal_init(self.atss_centerness, std=0.01)
+
+    def forward(self, feats):
+        return multi_apply(self.forward_single, feats, self.scales)
+
+    def forward_single(self, x, scale):
+        cls_feat = x
+        reg_feat = x
+        for cls_conv in self.cls_convs:
+            cls_feat = cls_conv(cls_feat)
+        for reg_conv in self.reg_convs:
+            reg_feat = reg_conv(reg_feat)
+        cls_score = self.atss_cls(cls_feat)
+        # following the official ATSS implementation, bbox_pred is only
+        # scaled; no exp is applied (unlike FCOS)
+        bbox_pred = scale(self.atss_reg(reg_feat)).float()
+        centerness = self.atss_centerness(reg_feat)
+        return cls_score, bbox_pred, centerness
+
+    def loss_single(self, anchors, cls_score, bbox_pred, centerness, labels,
+                    label_weights, bbox_targets, num_total_samples, cfg):
+
+        anchors = anchors.reshape(-1, 4)
+        cls_score = cls_score.permute(0, 2, 3,
+                                      1).reshape(-1, self.cls_out_channels)
+        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+        centerness = centerness.permute(0, 2, 3, 1).reshape(-1)
+        bbox_targets = bbox_targets.reshape(-1, 4)
+        labels = labels.reshape(-1)
+        label_weights = label_weights.reshape(-1)
+
+        # classification loss
+        loss_cls = self.loss_cls(
+            cls_score, labels, label_weights, avg_factor=num_total_samples)
+
+        pos_inds = torch.nonzero(labels).squeeze(1)
+
+        if len(pos_inds) > 0:
+            pos_bbox_targets = bbox_targets[pos_inds]
+            pos_bbox_pred = bbox_pred[pos_inds]
+            pos_anchors = anchors[pos_inds]
+            pos_centerness = centerness[pos_inds]
+
+            centerness_targets = self.centerness_target(
+                pos_anchors, pos_bbox_targets)
+            pos_decode_bbox_pred = delta2bbox(pos_anchors,
+                                              pos_bbox_pred,
+                                              self.target_means,
+                                              self.target_stds)
+            pos_decode_bbox_targets = delta2bbox(pos_anchors,
+                                                 pos_bbox_targets,
+                                                 self.target_means,
+                                                 self.target_stds)
+
+            # regression loss
+            loss_bbox = self.loss_bbox(
+                pos_decode_bbox_pred,
+                pos_decode_bbox_targets,
+                weight=centerness_targets,
+                avg_factor=1.0)
+
+            # centerness loss
+            loss_centerness = self.loss_centerness(
+                pos_centerness,
+                centerness_targets,
+                avg_factor=num_total_samples)
+
+        else:
+            loss_bbox = loss_cls * 0
+            loss_centerness = loss_bbox * 0
+            centerness_targets = torch.tensor(0).cuda()
+
+        return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum()
+
+    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
+    def loss(self,
+             cls_scores,
+             bbox_preds,
+             centernesses,
+             gt_bboxes,
+             gt_labels,
+             img_metas,
+             cfg,
+             gt_bboxes_ignore=None):
+
+        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+        assert len(featmap_sizes) == len(self.anchor_generators)
+
+        device = cls_scores[0].device
+        anchor_list, valid_flag_list = self.get_anchors(
+            featmap_sizes, img_metas, device=device)
+        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+        cls_reg_targets = self.atss_target(
+            anchor_list,
+            valid_flag_list,
+            gt_bboxes,
+            img_metas,
+            cfg,
+            gt_bboxes_ignore_list=gt_bboxes_ignore,
+            gt_labels_list=gt_labels,
+            label_channels=label_channels)
+        if cls_reg_targets is None:
+            return None
+
+        (anchor_list, labels_list, label_weights_list, bbox_targets_list,
+         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
+
+        num_total_samples = reduce_mean(
+            torch.tensor(num_total_pos).cuda()).item()
+        num_total_samples = max(num_total_samples, 1.0)
+
+        losses_cls, losses_bbox, loss_centerness,\
+            bbox_avg_factor = multi_apply(
+                self.loss_single,
+                anchor_list,
+                cls_scores,
+                bbox_preds,
+                centernesses,
+                labels_list,
+                label_weights_list,
+                bbox_targets_list,
+                num_total_samples=num_total_samples,
+                cfg=cfg)
+
+        bbox_avg_factor = sum(bbox_avg_factor)
+        bbox_avg_factor = reduce_mean(bbox_avg_factor).item()
+        losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
+        return dict(
+            loss_cls=losses_cls,
+            loss_bbox=losses_bbox,
+            loss_centerness=loss_centerness)
+
+    def centerness_target(self, anchors, bbox_targets):
+        # only calculate pos centerness targets, otherwise there may be nan
+        gts = delta2bbox(anchors, bbox_targets, self.target_means,
+                         self.target_stds)
+        anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
+        anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
+        l_ = anchors_cx - gts[:, 0]
+        t_ = anchors_cy - gts[:, 1]
+        r_ = gts[:, 2] - anchors_cx
+        b_ = gts[:, 3] - anchors_cy
+
+        left_right = torch.stack([l_, r_], dim=1)
+        top_bottom = torch.stack([t_, b_], dim=1)
+        centerness = torch.sqrt(
+            (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) *
+            (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]))
+        assert not torch.isnan(centerness).any()
+        return centerness
+
+    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
+    def get_bboxes(self,
+                   cls_scores,
+                   bbox_preds,
+                   centernesses,
+                   img_metas,
+                   cfg,
+                   rescale=False):
+
+        assert len(cls_scores) == len(bbox_preds)
+        num_levels = len(cls_scores)
+        device = cls_scores[0].device
+        mlvl_anchors = [
+            self.anchor_generators[i].grid_anchors(
+                cls_scores[i].size()[-2:],
+                self.anchor_strides[i],
+                device=device) for i in range(num_levels)
+        ]
+
+        result_list = []
+        for img_id in range(len(img_metas)):
+            cls_score_list = [
+                cls_scores[i][img_id].detach() for i in range(num_levels)
+            ]
+            bbox_pred_list = [
+                bbox_preds[i][img_id].detach() for i in range(num_levels)
+            ]
+            centerness_pred_list = [
+                centernesses[i][img_id].detach() for i in range(num_levels)
+            ]
+            img_shape = img_metas[img_id]['img_shape']
+            scale_factor = img_metas[img_id]['scale_factor']
+            proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
+                                               centerness_pred_list,
+                                               mlvl_anchors, img_shape,
+                                               scale_factor, cfg, rescale)
+            result_list.append(proposals)
+        return result_list
+
+    def get_bboxes_single(self,
+                          cls_scores,
+                          bbox_preds,
+                          centernesses,
+                          mlvl_anchors,
+                          img_shape,
+                          scale_factor,
+                          cfg,
+                          rescale=False):
+        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
+        mlvl_bboxes = []
+        mlvl_scores = []
+        mlvl_centerness = []
+        for cls_score, bbox_pred, centerness, anchors in zip(
+                cls_scores, bbox_preds, centernesses, mlvl_anchors):
+            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+
+            scores = cls_score.permute(1, 2, 0).reshape(
+                -1, self.cls_out_channels).sigmoid()
+            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+            centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
+
+            nms_pre = cfg.get('nms_pre', -1)
+            if nms_pre > 0 and scores.shape[0] > nms_pre:
+                max_scores, _ = (scores * centerness[:, None]).max(dim=1)
+                _, topk_inds = max_scores.topk(nms_pre)
+                anchors = anchors[topk_inds, :]
+                bbox_pred = bbox_pred[topk_inds, :]
+                scores = scores[topk_inds, :]
+                centerness = centerness[topk_inds]
+
+            bboxes = delta2bbox(anchors, bbox_pred, self.target_means,
+                                self.target_stds, img_shape)
+            mlvl_bboxes.append(bboxes)
+            mlvl_scores.append(scores)
+            mlvl_centerness.append(centerness)
+
+        mlvl_bboxes = torch.cat(mlvl_bboxes)
+        if rescale:
+            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
+
+        mlvl_scores = torch.cat(mlvl_scores)
+        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
+        mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
+        mlvl_centerness = torch.cat(mlvl_centerness)
+
+        det_bboxes, det_labels = multiclass_nms(
+            mlvl_bboxes,
+            mlvl_scores,
+            cfg.score_thr,
+            cfg.nms,
+            cfg.max_per_img,
+            score_factors=mlvl_centerness)
+        return det_bboxes, det_labels
+
+    def atss_target(self,
+                    anchor_list,
+                    valid_flag_list,
+                    gt_bboxes_list,
+                    img_metas,
+                    cfg,
+                    gt_bboxes_ignore_list=None,
+                    gt_labels_list=None,
+                    label_channels=1,
+                    unmap_outputs=True):
+        """
+        Almost the same as anchor_target, with a slight modification:
+        we also need to return the anchors here.
+        """
+        num_imgs = len(img_metas)
+        assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+        # anchor number of multi levels
+        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+        num_level_anchors_list = [num_level_anchors] * num_imgs
+
+        # concat all level anchors and flags to a single tensor
+        for i in range(num_imgs):
+            assert len(anchor_list[i]) == len(valid_flag_list[i])
+            anchor_list[i] = torch.cat(anchor_list[i])
+            valid_flag_list[i] = torch.cat(valid_flag_list[i])
+
+        # compute targets for each image
+        if gt_bboxes_ignore_list is None:
+            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+        if gt_labels_list is None:
+            gt_labels_list = [None for _ in range(num_imgs)]
+        (all_anchors, all_labels, all_label_weights, all_bbox_targets,
+         all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
+             self.atss_target_single,
+             anchor_list,
+             valid_flag_list,
+             num_level_anchors_list,
+             gt_bboxes_list,
+             gt_bboxes_ignore_list,
+             gt_labels_list,
+             img_metas,
+             cfg=cfg,
+             label_channels=label_channels,
+             unmap_outputs=unmap_outputs)
+        # no valid anchors
+        if any([labels is None for labels in all_labels]):
+            return None
+        # sampled anchors of all images
+        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+        # split targets to a list w.r.t. multiple levels
+        anchors_list = images_to_levels(all_anchors, num_level_anchors)
+        labels_list = images_to_levels(all_labels, num_level_anchors)
+        label_weights_list = images_to_levels(all_label_weights,
+                                              num_level_anchors)
+        bbox_targets_list = images_to_levels(all_bbox_targets,
+                                             num_level_anchors)
+        bbox_weights_list = images_to_levels(all_bbox_weights,
+                                             num_level_anchors)
+        return (anchors_list, labels_list, label_weights_list,
+                bbox_targets_list, bbox_weights_list, num_total_pos,
+                num_total_neg)
+
+    def atss_target_single(self,
+                           flat_anchors,
+                           valid_flags,
+                           num_level_anchors,
+                           gt_bboxes,
+                           gt_bboxes_ignore,
+                           gt_labels,
+                           img_meta,
+                           cfg,
+                           label_channels=1,
+                           unmap_outputs=True):
+        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+                                           img_meta['img_shape'][:2],
+                                           cfg.allowed_border)
+        if not inside_flags.any():
+            return (None, ) * 6
+        # assign gt and sample anchors
+        anchors = flat_anchors[inside_flags, :]
+
+        num_level_anchors_inside = self.get_num_level_anchors_inside(
+            num_level_anchors, inside_flags)
+        bbox_assigner = build_assigner(cfg.assigner)
+        assign_result = bbox_assigner.assign(anchors,
+                                             num_level_anchors_inside,
+                                             gt_bboxes, gt_bboxes_ignore,
+                                             gt_labels)
+
+        bbox_sampler = PseudoSampler()
+        sampling_result = bbox_sampler.sample(assign_result, anchors,
+                                              gt_bboxes)
+
+        num_valid_anchors = anchors.shape[0]
+        bbox_targets = torch.zeros_like(anchors)
+        bbox_weights = torch.zeros_like(anchors)
+        labels = anchors.new_zeros(num_valid_anchors, dtype=torch.long)
+        label_weights = anchors.new_zeros(num_valid_anchors,
+                                          dtype=torch.float)
+
+        pos_inds = sampling_result.pos_inds
+        neg_inds = sampling_result.neg_inds
+        if len(pos_inds) > 0:
+            pos_bbox_targets = bbox2delta(sampling_result.pos_bboxes,
+                                          sampling_result.pos_gt_bboxes,
+                                          self.target_means, self.target_stds)
+            bbox_targets[pos_inds, :] = pos_bbox_targets
+            bbox_weights[pos_inds, :] = 1.0
+            if gt_labels is None:
+                labels[pos_inds] = 1
+            else:
+                labels[pos_inds] = gt_labels[
+                    sampling_result.pos_assigned_gt_inds]
+            if cfg.pos_weight <= 0:
+                label_weights[pos_inds] = 1.0
+            else:
+                label_weights[pos_inds] = cfg.pos_weight
+        if len(neg_inds) > 0:
+            label_weights[neg_inds] = 1.0
+
+        # map up to original set of anchors
+        if unmap_outputs:
+            num_total_anchors = flat_anchors.size(0)
+            anchors = unmap(anchors, num_total_anchors, inside_flags)
+            labels = unmap(labels, num_total_anchors, inside_flags)
+            label_weights = unmap(label_weights, num_total_anchors,
+                                  inside_flags)
+            bbox_targets = unmap(bbox_targets, num_total_anchors,
+                                 inside_flags)
+            bbox_weights = unmap(bbox_weights, num_total_anchors,
+                                 inside_flags)
+
+        return (anchors, labels, label_weights, bbox_targets, bbox_weights,
+                pos_inds, neg_inds)
+
+    def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
+        split_inside_flags = torch.split(inside_flags, num_level_anchors)
+        num_level_anchors_inside = [
+            int(flags.sum()) for flags in split_inside_flags
+        ]
+        return num_level_anchors_inside
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index f9f83c621def00b84052ff41266180cfa8dec788..70a53d876952bce9a0916785c43e5c5dcd7c73ba 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -1,3 +1,4 @@
+from .atss import ATSS
 from .base import BaseDetector
 from .cascade_rcnn import CascadeRCNN
 from .double_head_rcnn import DoubleHeadRCNN
@@ -16,7 +17,7 @@ from .single_stage import SingleStageDetector
 from .two_stage import TwoStageDetector
 
 __all__ = [
-    'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
+    'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
     'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
     'DoubleHeadRCNN', 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN',
     'RepPointsDetector', 'FOVEA'
 ]
diff --git a/mmdet/models/detectors/atss.py b/mmdet/models/detectors/atss.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac22bf928673334c52a67e653faadb63b8df9ea8
--- /dev/null
+++ b/mmdet/models/detectors/atss.py
@@ -0,0 +1,16 @@
+from ..registry import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module
+class ATSS(SingleStageDetector):
+
+    def __init__(self,
+                 backbone,
+                 neck,
+                 bbox_head,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None):
+        super(ATSS, self).__init__(backbone, neck, bbox_head, train_cfg,
+                                   test_cfg, pretrained)
diff --git a/mmdet/models/losses/__init__.py b/mmdet/models/losses/__init__.py
index 992741ae4b8cde5897782ecb4b420c6971ba8b8b..07731d71053c92bf8542180f20ca613d45acbed6 100644
--- a/mmdet/models/losses/__init__.py
+++ b/mmdet/models/losses/__init__.py
@@ -4,7 +4,8 @@ from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
                                  cross_entropy, mask_cross_entropy)
 from .focal_loss import FocalLoss, sigmoid_focal_loss
 from .ghm_loss import GHMC, GHMR
-from .iou_loss import BoundedIoULoss, IoULoss, bounded_iou_loss, iou_loss
+from .iou_loss import (BoundedIoULoss, GIoULoss, IoULoss, bounded_iou_loss,
+                       iou_loss)
 from .mse_loss import MSELoss, mse_loss
 from .smooth_l1_loss import SmoothL1Loss, smooth_l1_loss
 from .utils import reduce_loss, weight_reduce_loss, weighted_loss
@@ -14,6 +15,6 @@ __all__ = [
     'mask_cross_entropy', 'CrossEntropyLoss', 'sigmoid_focal_loss',
     'FocalLoss', 'smooth_l1_loss', 'SmoothL1Loss', 'balanced_l1_loss',
     'BalancedL1Loss', 'mse_loss', 'MSELoss', 'iou_loss', 'bounded_iou_loss',
-    'IoULoss', 'BoundedIoULoss', 'GHMC', 'GHMR', 'reduce_loss',
+    'IoULoss', 'BoundedIoULoss', 'GIoULoss', 'GHMC', 'GHMR', 'reduce_loss',
     'weight_reduce_loss', 'weighted_loss'
 ]
diff --git a/mmdet/models/losses/iou_loss.py b/mmdet/models/losses/iou_loss.py
index cf2499463d94a557e9ca85844c1717f186be3833..c19c1d1d6af86926c31c9a28e974ebfc046c2b33 100644
--- a/mmdet/models/losses/iou_loss.py
+++ b/mmdet/models/losses/iou_loss.py
@@ -69,6 +69,51 @@ def bounded_iou_loss(pred, target, beta=0.2, eps=1e-3):
     return loss
 
 
+@weighted_loss
+def giou_loss(pred, target, eps=1e-7):
+    """
+    Generalized Intersection over Union: A Metric and A Loss for
+    Bounding Box Regression
+    https://arxiv.org/abs/1902.09630
+
+    code adapted from:
+    https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py#L36
+
+    Args:
+        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
+            shape (n, 4).
+        target (Tensor): Corresponding gt bboxes, shape (n, 4).
+        eps (float): epsilon to avoid division by zero.
+
+    Returns:
+        Tensor: Loss tensor.
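+
+    (box widths and heights follow the legacy +1 pixel convention used by
+    bbox_overlaps elsewhere in this codebase)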
+ """ + # overlap + lt = torch.max(pred[:, :2], target[:, :2]) + rb = torch.min(pred[:, 2:], target[:, 2:]) + wh = (rb - lt + 1).clamp(min=0) + overlap = wh[:, 0] * wh[:, 1] + + # union + ap = (pred[:, 2] - pred[:, 0] + 1) * (pred[:, 3] - pred[:, 1] + 1) + ag = (target[:, 2] - target[:, 0] + 1) * (target[:, 3] - target[:, 1] + 1) + union = ap + ag - overlap + eps + + # IoU + ious = overlap / union + + # enclose area + enclose_x1y1 = torch.min(pred[:, :2], target[:, :2]) + enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:]) + enclose_wh = (enclose_x2y2 - enclose_x1y1 + 1).clamp(min=0) + enclose_area = enclose_wh[:, 0] * enclose_wh[:, 1] + eps + + # GIoU + gious = ious - (enclose_area - union) / enclose_area + loss = 1 - gious + return loss + + @LOSSES.register_module class IoULoss(nn.Module): @@ -133,3 +178,35 @@ class BoundedIoULoss(nn.Module): avg_factor=avg_factor, **kwargs) return loss + + +@LOSSES.register_module +class GIoULoss(nn.Module): + + def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0): + super(GIoULoss, self).__init__() + self.eps = eps + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + if weight is not None and not torch.any(weight > 0): + return (pred * weight).sum() # 0 + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + loss = self.loss_weight * giou_loss( + pred, + target, + weight, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss