Skip to content
Snippets Groups Projects
Commit d1c71db7 authored by Cao Yuhang's avatar Cao Yuhang Committed by Kai Chen
Browse files

Reimplement "FreeAnchor: Learning to Match Anchors for Visual Object Detection" (#1391)

* add FreeAnchorRetinanet

* fix std

* refactor code

* add cfg

* fix bug in cfg

* add annotation and clean

* add readme

* fix isort error

* replace clip with torch.clamp

* fix init typo

* refactoring

* clean

* update benchmark

* rename and delete some var

* minor fix
parent 01c0489d
No related branches found
No related tags found
No related merge requests found
# FreeAnchor: Learning to Match Anchors for Visual Object Detection
## Introduction
```
@inproceedings{zhang2019freeanchor,
title = {{FreeAnchor}: Learning to Match Anchors for Visual Object Detection},
author = {Zhang, Xiaosong and Wan, Fang and Liu, Chang and Ji, Rongrong and Ye, Qixiang},
booktitle = {Neural Information Processing Systems},
year = {2019}
}
```
## Results and Models
| Backbone | Style | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
|:---------:|:-------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:--------:|
| R-50 | pytorch | 1x | 4.7 | 0.322 | 12.0 | 38.4 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/free_anchor/retinanet_free_anchor_r50_fpn_1x_20190914-84db6585.pth) |
| R-101 | pytorch | 1x | 6.6 | 0.437 | 9.7 | 40.3 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/free_anchor/retinanet_free_anchor_r101_fpn_1x_20190914-c4e4db81.pth) |
| X-101-32x4d | pytorch | 1x | 7.8 | 0.640 | 8.4 | 42.0 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/free_anchor/retinanet_free_anchor_x101-32x4d_fpn_1x_20190914-eb73b804.pth) |
**Notes:**
- We use 8 GPUs with 2 images/GPU.
# model settings
model = dict(
type='RetinaNet',
pretrained='torchvision://resnet101',
backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs=True,
num_outs=5),
bbox_head=dict(
type='FreeAnchorRetinaHead',
num_classes=81,
in_channels=256,
stacked_convs=4,
feat_channels=256,
octave_base_scale=4,
scales_per_octave=3,
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[8, 16, 32, 64, 128],
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2],
loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=0.75)))
# training and testing settings
train_cfg = dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False)
test_cfg = dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/retinanet_free_anchor_r101_fpn_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
# model settings
model = dict(
type='RetinaNet',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs=True,
num_outs=5),
bbox_head=dict(
type='FreeAnchorRetinaHead',
num_classes=81,
in_channels=256,
stacked_convs=4,
feat_channels=256,
octave_base_scale=4,
scales_per_octave=3,
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[8, 16, 32, 64, 128],
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2],
loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=0.75)))
# training and testing settings
train_cfg = dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False)
test_cfg = dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/retinanet_free_anchor_r50_fpn_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
# model settings
model = dict(
type='RetinaNet',
pretrained='open-mmlab://resnext101_32x4d',
backbone=dict(
type='ResNeXt',
depth=101,
groups=32,
base_width=4,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs=True,
num_outs=5),
bbox_head=dict(
type='FreeAnchorRetinaHead',
num_classes=81,
in_channels=256,
stacked_convs=4,
feat_channels=256,
octave_base_scale=4,
scales_per_octave=3,
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[8, 16, 32, 64, 128],
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2],
loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=0.75)))
# training and testing settings
train_cfg = dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False)
test_cfg = dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/retinanet_free_anchor_x101-32x4d_fpn_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
from .anchor_head import AnchorHead
from .fcos_head import FCOSHead
from .fovea_head import FoveaHead
from .free_anchor_retina_head import FreeAnchorRetinaHead
from .ga_retina_head import GARetinaHead
from .ga_rpn_head import GARPNHead
from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
......@@ -12,5 +13,5 @@ from .ssd_head import SSDHead
__all__ = [
'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
'GARPNHead', 'RetinaHead', 'GARetinaHead', 'SSDHead', 'FCOSHead',
'RepPointsHead', 'FoveaHead'
'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead'
]
import torch
import torch.nn.functional as F
from mmdet.core import bbox2delta, bbox_overlaps, delta2bbox
from ..registry import HEADS
from .retina_head import RetinaHead
@HEADS.register_module
class FreeAnchorRetinaHead(RetinaHead):
def __init__(self,
num_classes,
in_channels,
stacked_convs=4,
octave_base_scale=4,
scales_per_octave=3,
conv_cfg=None,
norm_cfg=None,
pre_anchor_topk=50,
bbox_thr=0.6,
gamma=2.0,
alpha=0.5,
**kwargs):
super(FreeAnchorRetinaHead,
self).__init__(num_classes, in_channels, stacked_convs,
octave_base_scale, scales_per_octave, conv_cfg,
norm_cfg, **kwargs)
self.pre_anchor_topk = pre_anchor_topk
self.bbox_thr = bbox_thr
self.gamma = gamma
self.alpha = alpha
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == len(self.anchor_generators)
anchor_list, _ = self.get_anchors(featmap_sizes, img_metas)
anchors = [torch.cat(anchor) for anchor in anchor_list]
# concatenate each level
cls_scores = [
cls.permute(0, 2, 3,
1).reshape(cls.size(0), -1, self.cls_out_channels)
for cls in cls_scores
]
bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(bbox_pred.size(0), -1, 4)
for bbox_pred in bbox_preds
]
cls_scores = torch.cat(cls_scores, dim=1)
bbox_preds = torch.cat(bbox_preds, dim=1)
cls_prob = torch.sigmoid(cls_scores)
box_prob = []
num_pos = 0
positive_losses = []
for _, (anchors_, gt_labels_, gt_bboxes_, cls_prob_,
bbox_preds_) in enumerate(
zip(anchors, gt_labels, gt_bboxes, cls_prob, bbox_preds)):
gt_labels_ -= 1
with torch.no_grad():
# box_localization: a_{j}^{loc}, shape: [j, 4]
pred_boxes = delta2bbox(anchors_, bbox_preds_,
self.target_means, self.target_stds)
# object_box_iou: IoU_{ij}^{loc}, shape: [i, j]
object_box_iou = bbox_overlaps(gt_bboxes_, pred_boxes)
# object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j]
t1 = self.bbox_thr
t2 = object_box_iou.max(
dim=1, keepdim=True).values.clamp(min=t1 + 1e-12)
object_box_prob = ((object_box_iou - t1) / (t2 - t1)).clamp(
min=0, max=1)
# object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j]
num_obj = gt_labels_.size(0)
indices = torch.stack(
[torch.arange(num_obj).type_as(gt_labels_), gt_labels_],
dim=0)
object_cls_box_prob = torch.sparse_coo_tensor(
indices, object_box_prob)
# image_box_iou: P{a_{j} \in A_{+}}, shape: [c, j]
"""
from "start" to "end" implement:
image_box_iou = torch.sparse.max(object_cls_box_prob,
dim=0).t()
"""
# start
box_cls_prob = torch.sparse.sum(
object_cls_box_prob, dim=0).to_dense()
indices = torch.nonzero(box_cls_prob).t_()
if indices.numel() == 0:
image_box_prob = torch.zeros(
anchors_.size(0),
self.cls_out_channels).type_as(object_box_prob)
else:
nonzero_box_prob = torch.where(
(gt_labels_.unsqueeze(dim=-1) == indices[0]),
object_box_prob[:, indices[1]],
torch.tensor(
[0]).type_as(object_box_prob)).max(dim=0).values
# upmap to shape [j, c]
image_box_prob = torch.sparse_coo_tensor(
indices.flip([0]),
nonzero_box_prob,
size=(anchors_.size(0),
self.cls_out_channels)).to_dense()
# end
box_prob.append(image_box_prob)
# construct bags for objects
match_quality_matrix = bbox_overlaps(gt_bboxes_, anchors_)
_, matched = torch.topk(
match_quality_matrix,
self.pre_anchor_topk,
dim=1,
sorted=False)
del match_quality_matrix
# matched_cls_prob: P_{ij}^{cls}
matched_cls_prob = torch.gather(
cls_prob_[matched], 2,
gt_labels_.view(-1, 1, 1).repeat(1, self.pre_anchor_topk,
1)).squeeze(2)
# matched_box_prob: P_{ij}^{loc}
matched_anchors = anchors_[matched]
matched_object_targets = bbox2delta(
matched_anchors,
gt_bboxes_.unsqueeze(dim=1).expand_as(matched_anchors),
self.target_means, self.target_stds)
loss_bbox = self.loss_bbox(
bbox_preds_[matched],
matched_object_targets,
reduction_override='none').sum(-1)
matched_box_prob = torch.exp(-loss_bbox)
# positive_losses: {-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )}
num_pos += len(gt_bboxes_)
positive_losses.append(
self.positive_bag_loss(matched_cls_prob, matched_box_prob))
positive_loss = torch.cat(positive_losses).sum() / max(1, num_pos)
# box_prob: P{a_{j} \in A_{+}}
box_prob = torch.stack(box_prob, dim=0)
# negative_loss:
# \sum_{j}{ FL((1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg})) } / n||B||
negative_loss = self.negative_bag_loss(cls_prob, box_prob).sum() / max(
1, num_pos * self.pre_anchor_topk)
losses = {
'positive_bag_loss': positive_loss,
'negative_bag_loss': negative_loss
}
return losses
def positive_bag_loss(self, matched_cls_prob, matched_box_prob):
# bag_prob = Mean-max(matched_prob)
matched_prob = matched_cls_prob * matched_box_prob
weight = 1 / torch.clamp(1 - matched_prob, 1e-12, None)
weight /= weight.sum(dim=1).unsqueeze(dim=-1)
bag_prob = (weight * matched_prob).sum(dim=1)
# positive_bag_loss = -self.alpha * log(bag_prob)
return self.alpha * F.binary_cross_entropy(
bag_prob, torch.ones_like(bag_prob), reduction='none')
def negative_bag_loss(self, cls_prob, box_prob):
prob = cls_prob * (1 - box_prob)
negative_bag_loss = prob**self.gamma * F.binary_cross_entropy(
prob, torch.zeros_like(prob), reduction='none')
return (1 - self.alpha) * negative_bag_loss
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment