Skip to content
Snippets Groups Projects
Unverified Commit 4990aae6 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #44 from hellock/dev

Add BBoxAssigner and BBoxSampler components for better modular usage
parents f8fa51c9 bac11303
No related branches found
No related tags found
No related merge requests found
...@@ -43,17 +43,19 @@ model = dict( ...@@ -43,17 +43,19 @@ model = dict(
# model training and testing settings # model training and testing settings
train_cfg = dict( train_cfg = dict(
rcnn=dict( rcnn=dict(
assigner=dict(
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True,
pos_balance_sampling=False,
neg_balance_thr=0),
mask_size=28, mask_size=28,
pos_iou_thr=0.5,
neg_iou_thr=0.5,
crowd_thr=1.1,
roi_batch_size=512,
add_gt_as_proposals=True,
pos_fraction=0.25,
pos_balance_sampling=False,
neg_pos_ub=512,
neg_balance_thr=0,
min_pos_iou=0.5,
pos_weight=-1, pos_weight=-1,
debug=False)) debug=False))
test_cfg = dict( test_cfg = dict(
......
...@@ -32,16 +32,18 @@ model = dict( ...@@ -32,16 +32,18 @@ model = dict(
# model training and testing settings # model training and testing settings
train_cfg = dict( train_cfg = dict(
rcnn=dict( rcnn=dict(
pos_iou_thr=0.5, assigner=dict(
neg_iou_thr=0.5, pos_iou_thr=0.5,
crowd_thr=1.1, neg_iou_thr=0.5,
roi_batch_size=512, min_pos_iou=0.5,
add_gt_as_proposals=True, ignore_iof_thr=-1),
pos_fraction=0.25, sampler=dict(
pos_balance_sampling=False, num=512,
neg_pos_ub=512, pos_fraction=0.25,
neg_balance_thr=0, neg_pos_ub=-1,
min_pos_iou=0.5, add_gt_as_proposals=True,
pos_balance_sampling=False,
neg_balance_thr=0),
pos_weight=-1, pos_weight=-1,
debug=False)) debug=False))
test_cfg = dict(rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5)) test_cfg = dict(rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5))
......
...@@ -42,30 +42,35 @@ model = dict( ...@@ -42,30 +42,35 @@ model = dict(
# model training and testing settings # model training and testing settings
train_cfg = dict( train_cfg = dict(
rpn=dict( rpn=dict(
pos_fraction=0.5, assigner=dict(
pos_balance_sampling=False, pos_iou_thr=0.7,
neg_pos_ub=256, neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False,
pos_balance_sampling=False,
neg_balance_thr=0),
allowed_border=0, allowed_border=0,
crowd_thr=1.1,
anchor_batch_size=256,
pos_iou_thr=0.7,
neg_iou_thr=0.3,
neg_balance_thr=0,
min_pos_iou=0.3,
pos_weight=-1, pos_weight=-1,
smoothl1_beta=1 / 9.0, smoothl1_beta=1 / 9.0,
debug=False), debug=False),
rcnn=dict( rcnn=dict(
pos_iou_thr=0.5, assigner=dict(
neg_iou_thr=0.5, pos_iou_thr=0.5,
crowd_thr=1.1, neg_iou_thr=0.5,
roi_batch_size=512, min_pos_iou=0.5,
add_gt_as_proposals=True, ignore_iof_thr=-1),
pos_fraction=0.25, sampler=dict(
pos_balance_sampling=False, num=512,
neg_pos_ub=512, pos_fraction=0.25,
neg_balance_thr=0, neg_pos_ub=-1,
min_pos_iou=0.5, add_gt_as_proposals=True,
pos_balance_sampling=False,
neg_balance_thr=0),
pos_weight=-1, pos_weight=-1,
debug=False)) debug=False))
test_cfg = dict( test_cfg = dict(
......
...@@ -53,31 +53,36 @@ model = dict( ...@@ -53,31 +53,36 @@ model = dict(
# model training and testing settings # model training and testing settings
train_cfg = dict( train_cfg = dict(
rpn=dict( rpn=dict(
pos_fraction=0.5, assigner=dict(
pos_balance_sampling=False, pos_iou_thr=0.7,
neg_pos_ub=256, neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False,
pos_balance_sampling=False,
neg_balance_thr=0),
allowed_border=0, allowed_border=0,
crowd_thr=1.1,
anchor_batch_size=256,
pos_iou_thr=0.7,
neg_iou_thr=0.3,
neg_balance_thr=0,
min_pos_iou=0.3,
pos_weight=-1, pos_weight=-1,
smoothl1_beta=1 / 9.0, smoothl1_beta=1 / 9.0,
debug=False), debug=False),
rcnn=dict( rcnn=dict(
assigner=dict(
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True,
pos_balance_sampling=False,
neg_balance_thr=0),
mask_size=28, mask_size=28,
pos_iou_thr=0.5,
neg_iou_thr=0.5,
crowd_thr=1.1,
roi_batch_size=512,
add_gt_as_proposals=True,
pos_fraction=0.25,
pos_balance_sampling=False,
neg_pos_ub=512,
neg_balance_thr=0,
min_pos_iou=0.5,
pos_weight=-1, pos_weight=-1,
debug=False)) debug=False))
test_cfg = dict( test_cfg = dict(
......
...@@ -27,16 +27,19 @@ model = dict( ...@@ -27,16 +27,19 @@ model = dict(
# model training and testing settings # model training and testing settings
train_cfg = dict( train_cfg = dict(
rpn=dict( rpn=dict(
pos_fraction=0.5, assigner=dict(
pos_balance_sampling=False, pos_iou_thr=0.7,
neg_pos_ub=256, neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False,
pos_balance_sampling=False,
neg_balance_thr=0),
allowed_border=0, allowed_border=0,
crowd_thr=1.1,
anchor_batch_size=256,
pos_iou_thr=0.7,
neg_iou_thr=0.3,
neg_balance_thr=0,
min_pos_iou=0.3,
pos_weight=-1, pos_weight=-1,
smoothl1_beta=1 / 9.0, smoothl1_beta=1 / 9.0,
debug=False)) debug=False))
......
import torch import torch
from ..bbox import bbox_assign, bbox2delta, bbox_sampling from ..bbox import assign_and_sample, bbox2delta
from ..utils import multi_apply from ..utils import multi_apply
...@@ -80,27 +80,20 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta, ...@@ -80,27 +80,20 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
return (None, ) * 6 return (None, ) * 6
# assign gt and sample anchors # assign gt and sample anchors
anchors = flat_anchors[inside_flags, :] anchors = flat_anchors[inside_flags, :]
assigned_gt_inds, argmax_overlaps, max_overlaps = bbox_assign( _, sampling_result = assign_and_sample(anchors, gt_bboxes, None, None, cfg)
anchors,
gt_bboxes,
pos_iou_thr=cfg.pos_iou_thr,
neg_iou_thr=cfg.neg_iou_thr,
min_pos_iou=cfg.min_pos_iou)
pos_inds, neg_inds = bbox_sampling(assigned_gt_inds, cfg.anchor_batch_size,
cfg.pos_fraction, cfg.neg_pos_ub,
cfg.pos_balance_sampling, max_overlaps,
cfg.neg_balance_thr)
num_valid_anchors = anchors.shape[0]
bbox_targets = torch.zeros_like(anchors) bbox_targets = torch.zeros_like(anchors)
bbox_weights = torch.zeros_like(anchors) bbox_weights = torch.zeros_like(anchors)
labels = torch.zeros_like(assigned_gt_inds) labels = anchors.new_zeros((num_valid_anchors, ))
label_weights = torch.zeros_like(assigned_gt_inds, dtype=anchors.dtype) label_weights = anchors.new_zeros((num_valid_anchors, ))
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
if len(pos_inds) > 0: if len(pos_inds) > 0:
pos_anchors = anchors[pos_inds, :] pos_bbox_targets = bbox2delta(sampling_result.pos_bboxes,
pos_gt_bbox = gt_bboxes[assigned_gt_inds[pos_inds] - 1, :] sampling_result.pos_gt_bboxes,
pos_bbox_targets = bbox2delta(pos_anchors, pos_gt_bbox, target_means, target_means, target_stds)
target_stds)
bbox_targets[pos_inds, :] = pos_bbox_targets bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0 bbox_weights[pos_inds, :] = 1.0
labels[pos_inds] = 1 labels[pos_inds] = 1
......
from .geometry import bbox_overlaps from .geometry import bbox_overlaps
from .sampling import (random_choice, bbox_assign, bbox_assign_wrt_overlaps, from .assignment import BBoxAssigner, AssignResult
bbox_sampling, bbox_sampling_pos, bbox_sampling_neg, from .sampling import (BBoxSampler, SamplingResult, assign_and_sample,
sample_bboxes) random_choice)
from .transforms import (bbox2delta, delta2bbox, bbox_flip, bbox_mapping, from .transforms import (bbox2delta, delta2bbox, bbox_flip, bbox_mapping,
bbox_mapping_back, bbox2roi, roi2bbox, bbox2result) bbox_mapping_back, bbox2roi, roi2bbox, bbox2result)
from .bbox_target import bbox_target from .bbox_target import bbox_target
__all__ = [ __all__ = [
'bbox_overlaps', 'random_choice', 'bbox_assign', 'bbox_overlaps', 'BBoxAssigner', 'AssignResult', 'BBoxSampler',
'bbox_assign_wrt_overlaps', 'bbox_sampling', 'bbox_sampling_pos', 'SamplingResult', 'assign_and_sample', 'random_choice', 'bbox2delta',
'bbox_sampling_neg', 'sample_bboxes', 'bbox2delta', 'delta2bbox', 'delta2bbox', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi',
'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'roi2bbox', 'bbox2result', 'bbox_target'
'bbox2result', 'bbox_target'
] ]
import torch
from .geometry import bbox_overlaps
class BBoxAssigner(object):
    """Assign a corresponding gt bbox or background to each bbox.

    Each bbox will be assigned `-1`, `0`, or a positive integer
    indicating the ground truth index.

    - -1: don't care
    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        pos_iou_thr (float): IoU threshold for positive bboxes.
        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
            A float marks bboxes with `0 <= max IoU < neg_iou_thr` as
            negative; a 2-tuple marks bboxes with
            `neg_iou_thr[0] <= max IoU < neg_iou_thr[1]` as negative.
        min_pos_iou (float): Minimum IoU for a bbox to be considered as a
            positive bbox. For RPN, it is usually set as 0.3, for Fast R-CNN,
            it is usually set as pos_iou_thr.
        ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
            `gt_bboxes_ignore` is specified). Non-positive values mean not
            ignoring any bboxes.
    """

    def __init__(self,
                 pos_iou_thr,
                 neg_iou_thr,
                 min_pos_iou=.0,
                 ignore_iof_thr=-1):
        self.pos_iou_thr = pos_iou_thr
        self.neg_iou_thr = neg_iou_thr
        self.min_pos_iou = min_pos_iou
        self.ignore_iof_thr = ignore_iof_thr

    def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
        """Assign gt to bboxes.

        This method assigns a gt bbox to every bbox (proposal/anchor); each
        bbox will be assigned -1, 0, or a positive number. -1 means don't
        care, 0 means negative sample, a positive number is the index
        (1-based) of the assigned gt.

        The assignment is done in the following steps, the order matters.

        1. assign every bbox to -1
        2. assign proposals whose iou with all gts < neg_iou_thr to 0
        3. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
           assign it to that bbox
        4. for each gt bbox, assign its nearest proposals (may be more than
           one) to itself

        Args:
            bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
                Only the first 4 columns are used.
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.

        Raises:
            ValueError: If `bboxes` or `gt_bboxes` has no rows.
        """
        if bboxes.shape[0] == 0 or gt_bboxes.shape[0] == 0:
            raise ValueError('No gt or bboxes')
        bboxes = bboxes[:, :4]
        overlaps = bbox_overlaps(bboxes, gt_bboxes)

        if (self.ignore_iof_thr > 0) and (gt_bboxes_ignore is not None) and (
                gt_bboxes_ignore.numel() > 0):
            ignore_overlaps = bbox_overlaps(
                bboxes, gt_bboxes_ignore, mode='iof')
            ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
            # Mask out whole rows of ignored bboxes. A boolean row mask is
            # used instead of `nonzero().squeeze()[:, 0]`, which fails because
            # the squeezed index tensor is no longer 2-D.
            ignore_mask = ignore_max_overlaps > self.ignore_iof_thr
            # Overlap -1 keeps these bboxes below the usual thresholds so
            # they normally remain "don't care" in the steps that follow.
            overlaps[ignore_mask, :] = -1

        assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
        return assign_result

    def assign_wrt_overlaps(self, overlaps, gt_labels=None):
        """Assign w.r.t. the overlaps of bboxes with gts.

        Args:
            overlaps (Tensor): Overlaps between n bboxes and k gt_bboxes,
                shape(n, k).
            gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result. Its `labels` field is
            None when `gt_labels` is not given.

        Raises:
            ValueError: If `overlaps` has no elements.
        """
        if overlaps.numel() == 0:
            raise ValueError('No gt or proposals')

        num_bboxes, num_gts = overlaps.size(0), overlaps.size(1)

        # 1. assign -1 by default
        assigned_gt_inds = overlaps.new_full(
            (num_bboxes, ), -1, dtype=torch.long)

        # guards against overlaps with more than two dimensions
        assert overlaps.size() == (num_bboxes, num_gts)

        # for each bbox, the max iou over all gts and which gt achieves it
        max_overlaps, argmax_overlaps = overlaps.max(dim=1)
        # for each gt, the max iou over all bboxes
        gt_max_overlaps, _ = overlaps.max(dim=0)

        # 2. assign negative: below the negative threshold.
        # NOTE: any other type of `neg_iou_thr` (e.g. an int) matches neither
        # branch, in which case no bbox is marked negative.
        if isinstance(self.neg_iou_thr, float):
            assigned_gt_inds[(max_overlaps >= 0)
                             & (max_overlaps < self.neg_iou_thr)] = 0
        elif isinstance(self.neg_iou_thr, tuple):
            assert len(self.neg_iou_thr) == 2
            assigned_gt_inds[(max_overlaps >= self.neg_iou_thr[0])
                             & (max_overlaps < self.neg_iou_thr[1])] = 0

        # 3. assign positive: above positive IoU threshold
        pos_inds = max_overlaps >= self.pos_iou_thr
        assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1

        # 4. assign fg: for each gt, the bboxes with the highest IoU (ties
        # included). This runs last, so it overrides steps 2 and 3.
        for i in range(num_gts):
            if gt_max_overlaps[i] >= self.min_pos_iou:
                assigned_gt_inds[overlaps[:, i] == gt_max_overlaps[i]] = i + 1

        if gt_labels is not None:
            # label 0 for every non-positive bbox, gt label for positives
            assigned_labels = assigned_gt_inds.new_zeros((num_bboxes, ))
            pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze()
            if pos_inds.numel() > 0:
                assigned_labels[pos_inds] = gt_labels[
                    assigned_gt_inds[pos_inds] - 1]
        else:
            assigned_labels = None

        return AssignResult(
            num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
class AssignResult(object):
    """Container for the outcome of a bbox-to-gt assignment.

    Args:
        num_gts (int): Number of gt bboxes the assignment was made against.
        gt_inds (Tensor): Per-bbox assigned gt index (1-based), with 0 for
            negative and -1 for don't care.
        max_overlaps (Tensor): Per-bbox max overlap over all gts.
        labels (Tensor, optional): Per-bbox assigned label, or None.
    """

    def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
        self.num_gts = num_gts
        self.gt_inds = gt_inds
        self.max_overlaps = max_overlaps
        self.labels = labels

    def add_gt_(self, gt_labels):
        """Prepend one self-assigned entry per gt, in place.

        Entry `i` (0-based) is assigned to gt `i + 1` with overlap 1, and
        takes `gt_labels[i]` as its label when labels are tracked.

        Args:
            gt_labels (Tensor): Labels of the gt bboxes being added; its
                length determines how many entries are prepended.
        """
        # Size every prepended field by len(gt_labels) so gt_inds,
        # max_overlaps and labels always stay the same length.
        num_added = len(gt_labels)
        self_inds = torch.arange(
            1, num_added + 1, dtype=torch.long, device=gt_labels.device)
        self.gt_inds = torch.cat([self_inds, self.gt_inds])
        self.max_overlaps = torch.cat(
            [self.max_overlaps.new_ones(num_added), self.max_overlaps])
        if self.labels is not None:
            self.labels = torch.cat([gt_labels, self.labels])
...@@ -4,23 +4,23 @@ from .transforms import bbox2delta ...@@ -4,23 +4,23 @@ from .transforms import bbox2delta
from ..utils import multi_apply from ..utils import multi_apply
def bbox_target(pos_proposals_list, def bbox_target(pos_bboxes_list,
neg_proposals_list, neg_bboxes_list,
pos_gt_bboxes_list, pos_gt_bboxes_list,
pos_gt_labels_list, pos_gt_labels_list,
cfg, cfg,
reg_num_classes=1, reg_classes=1,
target_means=[.0, .0, .0, .0], target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0], target_stds=[1.0, 1.0, 1.0, 1.0],
concat=True): concat=True):
labels, label_weights, bbox_targets, bbox_weights = multi_apply( labels, label_weights, bbox_targets, bbox_weights = multi_apply(
proposal_target_single, bbox_target_single,
pos_proposals_list, pos_bboxes_list,
neg_proposals_list, neg_bboxes_list,
pos_gt_bboxes_list, pos_gt_bboxes_list,
pos_gt_labels_list, pos_gt_labels_list,
cfg=cfg, cfg=cfg,
reg_num_classes=reg_num_classes, reg_classes=reg_classes,
target_means=target_means, target_means=target_means,
target_stds=target_stds) target_stds=target_stds)
...@@ -32,34 +32,34 @@ def bbox_target(pos_proposals_list, ...@@ -32,34 +32,34 @@ def bbox_target(pos_proposals_list,
return labels, label_weights, bbox_targets, bbox_weights return labels, label_weights, bbox_targets, bbox_weights
def proposal_target_single(pos_proposals, def bbox_target_single(pos_bboxes,
neg_proposals, neg_bboxes,
pos_gt_bboxes, pos_gt_bboxes,
pos_gt_labels, pos_gt_labels,
cfg, cfg,
reg_num_classes=1, reg_classes=1,
target_means=[.0, .0, .0, .0], target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]): target_stds=[1.0, 1.0, 1.0, 1.0]):
num_pos = pos_proposals.size(0) num_pos = pos_bboxes.size(0)
num_neg = neg_proposals.size(0) num_neg = neg_bboxes.size(0)
num_samples = num_pos + num_neg num_samples = num_pos + num_neg
labels = pos_proposals.new_zeros(num_samples, dtype=torch.long) labels = pos_bboxes.new_zeros(num_samples, dtype=torch.long)
label_weights = pos_proposals.new_zeros(num_samples) label_weights = pos_bboxes.new_zeros(num_samples)
bbox_targets = pos_proposals.new_zeros(num_samples, 4) bbox_targets = pos_bboxes.new_zeros(num_samples, 4)
bbox_weights = pos_proposals.new_zeros(num_samples, 4) bbox_weights = pos_bboxes.new_zeros(num_samples, 4)
if num_pos > 0: if num_pos > 0:
labels[:num_pos] = pos_gt_labels labels[:num_pos] = pos_gt_labels
pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
label_weights[:num_pos] = pos_weight label_weights[:num_pos] = pos_weight
pos_bbox_targets = bbox2delta(pos_proposals, pos_gt_bboxes, pos_bbox_targets = bbox2delta(pos_bboxes, pos_gt_bboxes, target_means,
target_means, target_stds) target_stds)
bbox_targets[:num_pos, :] = pos_bbox_targets bbox_targets[:num_pos, :] = pos_bbox_targets
bbox_weights[:num_pos, :] = 1 bbox_weights[:num_pos, :] = 1
if num_neg > 0: if num_neg > 0:
label_weights[-num_neg:] = 1.0 label_weights[-num_neg:] = 1.0
if reg_num_classes > 1: if reg_classes > 1:
bbox_targets, bbox_weights = expand_target(bbox_targets, bbox_weights, bbox_targets, bbox_weights = expand_target(bbox_targets, bbox_weights,
labels, reg_num_classes) labels, reg_classes)
return labels, label_weights, bbox_targets, bbox_weights return labels, label_weights, bbox_targets, bbox_weights
......
import numpy as np import numpy as np
import torch import torch
from .geometry import bbox_overlaps from .assignment import BBoxAssigner
def random_choice(gallery, num): def random_choice(gallery, num):
...@@ -21,323 +21,207 @@ def random_choice(gallery, num): ...@@ -21,323 +21,207 @@ def random_choice(gallery, num):
return gallery[rand_inds] return gallery[rand_inds]
def bbox_assign(proposals, def assign_and_sample(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg):
gt_bboxes, bbox_assigner = BBoxAssigner(**cfg.assigner)
gt_bboxes_ignore=None, bbox_sampler = BBoxSampler(**cfg.sampler)
gt_labels=None, assign_result = bbox_assigner.assign(bboxes, gt_bboxes, gt_bboxes_ignore,
pos_iou_thr=0.5, gt_labels)
neg_iou_thr=0.5, sampling_result = bbox_sampler.sample(assign_result, bboxes, gt_bboxes,
min_pos_iou=.0, gt_labels)
crowd_thr=-1): return assign_result, sampling_result
"""Assign a corresponding gt bbox or background to each proposal/anchor.
Each proposals will be assigned with `-1`, `0`, or a positive integer.
- -1: don't care class BBoxSampler(object):
- 0: negative sample, no assigned gt
- positive integer: positive sample, index (1-based) of assigned gt
If `gt_bboxes_ignore` is specified, bboxes which have iof (intersection
over foreground) with `gt_bboxes_ignore` above `crowd_thr` will be ignored.
Args:
proposals (Tensor): Proposals or RPN anchors, shape (n, 4).
gt_bboxes (Tensor): Ground truth bboxes, shape (k, 4).
gt_bboxes_ignore (Tensor, optional): shape(m, 4).
gt_labels (Tensor, optional): shape (k, ).
pos_iou_thr (float): IoU threshold for positive bboxes.
neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
min_pos_iou (float): Minimum iou for a bbox to be considered as a
positive bbox. For RPN, it is usually set as 0.3, for Fast R-CNN,
it is usually set as pos_iou_thr
crowd_thr (float): IoF threshold for ignoring bboxes. Negative value
for not ignoring any bboxes.
Returns:
tuple: (assigned_gt_inds, argmax_overlaps, max_overlaps), shape (n, )
"""
# calculate overlaps between the proposals and the gt boxes
overlaps = bbox_overlaps(proposals, gt_bboxes)
if overlaps.numel() == 0:
raise ValueError('No gt bbox or proposals')
# ignore proposals according to crowd bboxes
if (crowd_thr > 0) and (gt_bboxes_ignore is
not None) and (gt_bboxes_ignore.numel() > 0):
crowd_overlaps = bbox_overlaps(proposals, gt_bboxes_ignore, mode='iof')
crowd_max_overlaps, _ = crowd_overlaps.max(dim=1)
crowd_bboxes_inds = torch.nonzero(
crowd_max_overlaps > crowd_thr).long()
if crowd_bboxes_inds.numel() > 0:
overlaps[crowd_bboxes_inds, :] = -1
return bbox_assign_wrt_overlaps(overlaps, gt_labels, pos_iou_thr,
neg_iou_thr, min_pos_iou)
def bbox_assign_wrt_overlaps(overlaps,
gt_labels=None,
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=.0):
"""Assign a corresponding gt bbox or background to each proposal/anchor.
This method assign a gt bbox to every proposal, each proposals will be
assigned with -1, 0, or a positive number. -1 means don't care, 0 means
negative sample, positive number is the index (1-based) of assigned gt.
The assignment is done in following steps, the order matters:
1. assign every anchor to -1
2. assign proposals whose iou with all gts < neg_iou_thr to 0
3. for each anchor, if the iou with its nearest gt >= pos_iou_thr,
assign it to that bbox
4. for each gt bbox, assign its nearest proposals(may be more than one)
to itself
Args:
overlaps (Tensor): Overlaps between n proposals and k gt_bboxes,
shape(n, k).
gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).
pos_iou_thr (float): IoU threshold for positive bboxes.
neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
min_pos_iou (float): Minimum IoU for a bbox to be considered as a
positive bbox. This argument only affects the 4th step.
Returns:
tuple: (assigned_gt_inds, [assigned_labels], argmax_overlaps,
max_overlaps), shape (n, )
"""
num_bboxes, num_gts = overlaps.size(0), overlaps.size(1)
# 1. assign -1 by default
assigned_gt_inds = overlaps.new(num_bboxes).long().fill_(-1)
if overlaps.numel() == 0:
raise ValueError('No gt bbox or proposals')
assert overlaps.size() == (num_bboxes, num_gts)
# for each anchor, which gt best overlaps with it
# for each anchor, the max iou of all gts
max_overlaps, argmax_overlaps = overlaps.max(dim=1)
# for each gt, which anchor best overlaps with it
# for each gt, the max iou of all proposals
gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=0)
# 2. assign negative: below
if isinstance(neg_iou_thr, float):
assigned_gt_inds[(max_overlaps >= 0)
& (max_overlaps < neg_iou_thr)] = 0
elif isinstance(neg_iou_thr, tuple):
assert len(neg_iou_thr) == 2
assigned_gt_inds[(max_overlaps >= neg_iou_thr[0])
& (max_overlaps < neg_iou_thr[1])] = 0
# 3. assign positive: above positive IoU threshold
pos_inds = max_overlaps >= pos_iou_thr
assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1
# 4. assign fg: for each gt, proposals with highest IoU
for i in range(num_gts):
if gt_max_overlaps[i] >= min_pos_iou:
assigned_gt_inds[overlaps[:, i] == gt_max_overlaps[i]] = i + 1
if gt_labels is None:
return assigned_gt_inds, argmax_overlaps, max_overlaps
else:
assigned_labels = assigned_gt_inds.new(num_bboxes).fill_(0)
pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze()
if pos_inds.numel() > 0:
assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] -
1]
return assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps
def bbox_sampling_pos(assigned_gt_inds, num_expected, balance_sampling=True):
"""Balance sampling for positive bboxes/anchors.
1. calculate average positive num for each gt: num_per_gt
2. sample at most num_per_gt positives for each gt
3. random sampling from rest anchors if not enough fg
"""
pos_inds = torch.nonzero(assigned_gt_inds > 0)
if pos_inds.numel() != 0:
pos_inds = pos_inds.squeeze(1)
if pos_inds.numel() <= num_expected:
return pos_inds
elif not balance_sampling:
return random_choice(pos_inds, num_expected)
else:
unique_gt_inds = torch.unique(assigned_gt_inds[pos_inds].cpu())
num_gts = len(unique_gt_inds)
num_per_gt = int(round(num_expected / float(num_gts)) + 1)
sampled_inds = []
for i in unique_gt_inds:
inds = torch.nonzero(assigned_gt_inds == i.item())
if inds.numel() != 0:
inds = inds.squeeze(1)
else:
continue
if len(inds) > num_per_gt:
inds = random_choice(inds, num_per_gt)
sampled_inds.append(inds)
sampled_inds = torch.cat(sampled_inds)
if len(sampled_inds) < num_expected:
num_extra = num_expected - len(sampled_inds)
extra_inds = np.array(
list(set(pos_inds.cpu()) - set(sampled_inds.cpu())))
if len(extra_inds) > num_extra:
extra_inds = random_choice(extra_inds, num_extra)
extra_inds = torch.from_numpy(extra_inds).to(
assigned_gt_inds.device).long()
sampled_inds = torch.cat([sampled_inds, extra_inds])
elif len(sampled_inds) > num_expected:
sampled_inds = random_choice(sampled_inds, num_expected)
return sampled_inds
def bbox_sampling_neg(assigned_gt_inds,
num_expected,
max_overlaps=None,
balance_thr=0,
hard_fraction=0.5):
"""Balance sampling for negative bboxes/anchors.
Negative samples are split into 2 set: hard (balance_thr <= iou <
neg_iou_thr) and easy(iou < balance_thr). The sampling ratio is controlled
by `hard_fraction`.
"""
neg_inds = torch.nonzero(assigned_gt_inds == 0)
if neg_inds.numel() != 0:
neg_inds = neg_inds.squeeze(1)
if len(neg_inds) <= num_expected:
return neg_inds
elif balance_thr <= 0:
# uniform sampling among all negative samples
return random_choice(neg_inds, num_expected)
else:
assert max_overlaps is not None
max_overlaps = max_overlaps.cpu().numpy()
# balance sampling for negative samples
neg_set = set(neg_inds.cpu().numpy())
easy_set = set(
np.where(
np.logical_and(max_overlaps >= 0,
max_overlaps < balance_thr))[0])
hard_set = set(np.where(max_overlaps >= balance_thr)[0])
easy_neg_inds = list(easy_set & neg_set)
hard_neg_inds = list(hard_set & neg_set)
num_expected_hard = int(num_expected * hard_fraction)
if len(hard_neg_inds) > num_expected_hard:
sampled_hard_inds = random_choice(hard_neg_inds, num_expected_hard)
else:
sampled_hard_inds = np.array(hard_neg_inds, dtype=np.int)
num_expected_easy = num_expected - len(sampled_hard_inds)
if len(easy_neg_inds) > num_expected_easy:
sampled_easy_inds = random_choice(easy_neg_inds, num_expected_easy)
else:
sampled_easy_inds = np.array(easy_neg_inds, dtype=np.int)
sampled_inds = np.concatenate((sampled_easy_inds, sampled_hard_inds))
if len(sampled_inds) < num_expected:
num_extra = num_expected - len(sampled_inds)
extra_inds = np.array(list(neg_set - set(sampled_inds)))
if len(extra_inds) > num_extra:
extra_inds = random_choice(extra_inds, num_extra)
sampled_inds = np.concatenate((sampled_inds, extra_inds))
sampled_inds = torch.from_numpy(sampled_inds).long().to(
assigned_gt_inds.device)
return sampled_inds
def bbox_sampling(assigned_gt_inds,
num_expected,
pos_fraction,
neg_pos_ub,
pos_balance_sampling=True,
max_overlaps=None,
neg_balance_thr=0,
neg_hard_fraction=0.5):
"""Sample positive and negative bboxes given assigned results. """Sample positive and negative bboxes given assigned results.
Args: Args:
assigned_gt_inds (Tensor): Assigned gt indices for each bbox.
num_expected (int): Expected total samples (pos and neg).
pos_fraction (float): Positive sample fraction. pos_fraction (float): Positive sample fraction.
neg_pos_ub (float): Negative/Positive upper bound. neg_pos_ub (float): Negative/Positive upper bound.
pos_balance_sampling(bool): Whether to sample positive samples around pos_balance_sampling (bool): Whether to sample positive samples around
each gt bbox evenly. each gt bbox evenly.
max_overlaps (Tensor, optional): For each bbox, the max IoU of all gts.
Used for negative balance sampling only.
neg_balance_thr (float, optional): IoU threshold for simple/hard neg_balance_thr (float, optional): IoU threshold for simple/hard
negative balance sampling. negative balance sampling.
neg_hard_fraction (float, optional): Fraction of hard negative samples neg_hard_fraction (float, optional): Fraction of hard negative samples
for negative balance sampling. for negative balance sampling.
Returns:
tuple[Tensor]: positive bbox indices, negative bbox indices.
"""
num_expected_pos = int(num_expected * pos_fraction)
pos_inds = bbox_sampling_pos(assigned_gt_inds, num_expected_pos,
pos_balance_sampling)
# We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch)
pos_inds = pos_inds.unique()
num_sampled_pos = pos_inds.numel()
num_neg_max = int(
neg_pos_ub *
num_sampled_pos) if num_sampled_pos > 0 else int(neg_pos_ub)
num_expected_neg = min(num_neg_max, num_expected - num_sampled_pos)
neg_inds = bbox_sampling_neg(assigned_gt_inds, num_expected_neg,
max_overlaps, neg_balance_thr,
neg_hard_fraction)
neg_inds = neg_inds.unique()
return pos_inds, neg_inds
def sample_bboxes(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg):
"""Sample positive and negative bboxes.
This is a simple implementation of bbox sampling given candidates and
ground truth bboxes, which includes 3 steps.
1. Assign gt to each bbox.
2. Add gt bboxes to the sampling pool (optional).
3. Perform positive and negative sampling.
Args:
bboxes (Tensor): Boxes to be sampled from.
gt_bboxes (Tensor): Ground truth bboxes.
gt_bboxes_ignore (Tensor): Ignored ground truth bboxes. In MS COCO,
`crowd` bboxes are considered as ignored.
gt_labels (Tensor): Class labels of ground truth bboxes.
cfg (dict): Sampling configs.
Returns:
tuple[Tensor]: pos_bboxes, neg_bboxes, pos_assigned_gt_inds,
pos_gt_bboxes, pos_gt_labels
""" """
bboxes = bboxes[:, :4]
assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps = \
bbox_assign(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels,
cfg.pos_iou_thr, cfg.neg_iou_thr, cfg.min_pos_iou,
cfg.crowd_thr)
if cfg.add_gt_as_proposals:
bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
gt_assign_self = torch.arange(
1, len(gt_labels) + 1, dtype=torch.long, device=bboxes.device)
assigned_gt_inds = torch.cat([gt_assign_self, assigned_gt_inds])
assigned_labels = torch.cat([gt_labels, assigned_labels])
pos_inds, neg_inds = bbox_sampling( def __init__(self,
assigned_gt_inds, cfg.roi_batch_size, cfg.pos_fraction, cfg.neg_pos_ub, num,
cfg.pos_balance_sampling, max_overlaps, cfg.neg_balance_thr) pos_fraction,
neg_pos_ub=-1,
pos_bboxes = bboxes[pos_inds] add_gt_as_proposals=True,
neg_bboxes = bboxes[neg_inds] pos_balance_sampling=False,
pos_assigned_gt_inds = assigned_gt_inds[pos_inds] - 1 neg_balance_thr=0,
pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :] neg_hard_fraction=0.5):
pos_gt_labels = assigned_labels[pos_inds] self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
self.pos_balance_sampling = pos_balance_sampling
self.neg_balance_thr = neg_balance_thr
self.neg_hard_fraction = neg_hard_fraction
def _sample_pos(self, assign_result, num_expected):
    """Balance sampling for positive bboxes/anchors.

    1. calculate average positive num for each gt: num_per_gt
    2. sample at most num_per_gt positives for each gt
    3. random sampling from rest anchors if not enough fg

    Balancing only happens when `self.pos_balance_sampling` is set;
    otherwise positives are sampled uniformly at random.

    Args:
        assign_result: Assignment result; only its `gt_inds` tensor is read.
        num_expected (int): Maximum number of positive indices to return.

    Returns:
        Tensor: Indices of the sampled positive bboxes. When there are no
        positives at all, the empty `torch.nonzero` output is returned
        unchanged.
    """
    pos_inds = torch.nonzero(assign_result.gt_inds > 0)
    if pos_inds.numel() != 0:
        pos_inds = pos_inds.squeeze(1)
    if pos_inds.numel() <= num_expected:
        # not more positives than requested: keep them all
        return pos_inds
    elif not self.pos_balance_sampling:
        return random_choice(pos_inds, num_expected)
    else:
        unique_gt_inds = torch.unique(
            assign_result.gt_inds[pos_inds].cpu())
        num_gts = len(unique_gt_inds)
        num_per_gt = int(round(num_expected / float(num_gts)) + 1)
        sampled_inds = []
        for i in unique_gt_inds:
            inds = torch.nonzero(assign_result.gt_inds == i.item())
            if inds.numel() != 0:
                inds = inds.squeeze(1)
            else:
                continue
            if len(inds) > num_per_gt:
                inds = random_choice(inds, num_per_gt)
            sampled_inds.append(inds)
        sampled_inds = torch.cat(sampled_inds)
        if len(sampled_inds) < num_expected:
            num_extra = num_expected - len(sampled_inds)
            # Take the set difference on integer arrays. A Python set of
            # 0-d tensors hashes by object identity, so it would never
            # remove the already-sampled indices and would yield duplicates.
            extra_inds = np.setdiff1d(pos_inds.cpu().numpy(),
                                      sampled_inds.cpu().numpy())
            if len(extra_inds) > num_extra:
                extra_inds = random_choice(extra_inds, num_extra)
            extra_inds = torch.from_numpy(extra_inds).to(
                assign_result.gt_inds.device).long()
            sampled_inds = torch.cat([sampled_inds, extra_inds])
        elif len(sampled_inds) > num_expected:
            sampled_inds = random_choice(sampled_inds, num_expected)
        return sampled_inds
def _sample_neg(self, assign_result, num_expected):
    """Sample negative bboxes/anchors, optionally balancing hard and easy.

    Negatives are the positions where ``assign_result.gt_inds == 0``. When
    there are at most ``num_expected`` of them, all are returned. When
    ``self.neg_balance_thr <= 0`` they are sampled uniformly. Otherwise
    they are split into two sets by ``assign_result.max_overlaps``: hard
    (overlap >= ``neg_balance_thr``) and easy (0 <= overlap <
    ``neg_balance_thr``). ``self.neg_hard_fraction`` controls the share of
    hard samples; any shortfall is filled from the remaining negatives.

    Args:
        assign_result: Object exposing ``gt_inds`` and ``max_overlaps``
            tensors (presumably 1-D and of equal length, since indices
            from both are intersected).
        num_expected (int): Maximum number of negative indices to return.

    Returns:
        Tensor: Indices of the sampled negatives.
    """
    neg_inds = torch.nonzero(assign_result.gt_inds == 0)
    if neg_inds.numel() != 0:
        neg_inds = neg_inds.squeeze(1)
    if len(neg_inds) <= num_expected:
        return neg_inds
    if self.neg_balance_thr <= 0:
        # uniform sampling among all negative samples
        return random_choice(neg_inds, num_expected)

    # balance sampling for negative samples
    max_overlaps = assign_result.max_overlaps.cpu().numpy()
    neg_set = set(neg_inds.cpu().numpy())
    easy_set = set(
        np.where(
            np.logical_and(max_overlaps >= 0,
                           max_overlaps < self.neg_balance_thr))[0])
    hard_set = set(np.where(max_overlaps >= self.neg_balance_thr)[0])
    easy_neg_inds = list(easy_set & neg_set)
    hard_neg_inds = list(hard_set & neg_set)

    # np.int64 instead of the removed ``np.int`` alias (gone in NumPy 1.24).
    num_expected_hard = int(num_expected * self.neg_hard_fraction)
    if len(hard_neg_inds) > num_expected_hard:
        sampled_hard_inds = random_choice(hard_neg_inds, num_expected_hard)
    else:
        sampled_hard_inds = np.array(hard_neg_inds, dtype=np.int64)

    # Easy samples take whatever budget the hard samples left over.
    num_expected_easy = num_expected - len(sampled_hard_inds)
    if len(easy_neg_inds) > num_expected_easy:
        sampled_easy_inds = random_choice(easy_neg_inds, num_expected_easy)
    else:
        sampled_easy_inds = np.array(easy_neg_inds, dtype=np.int64)

    sampled_inds = np.concatenate((sampled_easy_inds, sampled_hard_inds))
    if len(sampled_inds) < num_expected:
        # Fill up from negatives that fell into neither set.
        num_extra = num_expected - len(sampled_inds)
        extra_inds = np.array(
            list(neg_set - set(sampled_inds)), dtype=np.int64)
        if len(extra_inds) > num_extra:
            extra_inds = random_choice(extra_inds, num_extra)
        sampled_inds = np.concatenate((sampled_inds, extra_inds))
    sampled_inds = torch.from_numpy(sampled_inds).long().to(
        assign_result.gt_inds.device)
    return sampled_inds
def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None):
    """Sample positive and negative bboxes from already assigned boxes.

    The method works on the result of a previous assignment step:

    1. Keep the first 4 columns of ``bboxes``.
    2. Optionally (``self.add_gt_as_proposals``) prepend ``gt_bboxes`` to
       the pool and call ``assign_result.add_gt_(gt_labels)``.
    3. Sample positives, then negatives with the remaining budget,
       optionally capped by ``self.neg_pos_ub``.

    Args:
        assign_result (:obj:`AssignResult`): Bbox assigning results.
        bboxes (Tensor): Boxes to be sampled from.
        gt_bboxes (Tensor): Ground truth bboxes.
        gt_labels (Tensor, optional): Class labels of ground truth bboxes.

    Returns:
        :obj:`SamplingResult`: Sampling result.
    """
    bboxes = bboxes[:, :4]
    # Flag per box: 1 marks a box that is a ground truth, 0 a proposal.
    gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
    if self.add_gt_as_proposals:
        bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
        assign_result.add_gt_(gt_labels)
        gt_ones = bboxes.new_ones((gt_bboxes.shape[0], ), dtype=torch.uint8)
        gt_flags = torch.cat([gt_ones, gt_flags])

    pos_budget = int(self.num * self.pos_fraction)
    # Sampled indices were occasionally observed to contain duplicated
    # items (may be a bug of PyTorch), hence the deduplication.
    pos_inds = self._sample_pos(assign_result, pos_budget).unique()
    pos_count = pos_inds.numel()

    neg_budget = self.num - pos_count
    if self.neg_pos_ub >= 0:
        # Cap negatives at neg_pos_ub times the positives; with no
        # positives the cap is neg_pos_ub itself.
        if pos_count > 0:
            neg_cap = int(self.neg_pos_ub * pos_count)
        else:
            neg_cap = int(self.neg_pos_ub)
        neg_budget = min(neg_cap, neg_budget)
    neg_inds = self._sample_neg(assign_result, neg_budget).unique()

    return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
                          assign_result, gt_flags)
class SamplingResult(object):
    """Container for the boxes selected by a sampling step.

    Attributes set by ``__init__``:
        pos_inds, neg_inds: The sampled indices, stored as given.
        pos_bboxes, neg_bboxes: ``bboxes`` indexed by those indices.
        pos_is_gt: ``gt_flags`` indexed by ``pos_inds``.
        num_gts: Number of rows of ``gt_bboxes``.
        pos_assigned_gt_inds: ``assign_result.gt_inds[pos_inds] - 1``, used
            as a row index into ``gt_bboxes``.
        pos_gt_bboxes: Rows of ``gt_bboxes`` matched to each positive.
        pos_gt_labels: ``assign_result.labels[pos_inds]``, or None when the
            assign result carries no labels.
    """

    def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
                 gt_flags):
        self.pos_inds, self.neg_inds = pos_inds, neg_inds
        self.pos_bboxes, self.neg_bboxes = bboxes[pos_inds], bboxes[neg_inds]
        self.pos_is_gt = gt_flags[pos_inds]

        self.num_gts = gt_bboxes.shape[0]
        # Shift by one so the stored value indexes gt_bboxes directly.
        matched_gt = assign_result.gt_inds[pos_inds] - 1
        self.pos_assigned_gt_inds = matched_gt
        self.pos_gt_bboxes = gt_bboxes[matched_gt, :]

        labels = assign_result.labels
        self.pos_gt_labels = None if labels is None else labels[pos_inds]
return (pos_bboxes, neg_bboxes, pos_assigned_gt_inds, pos_gt_bboxes, @property
pos_gt_labels) def bboxes(self):
return torch.cat([self.pos_bboxes, self.neg_bboxes])
...@@ -215,7 +215,7 @@ class CocoDataset(Dataset): ...@@ -215,7 +215,7 @@ class CocoDataset(Dataset):
'proposals should have shapes (n, 4) or (n, 5), ' 'proposals should have shapes (n, 4) or (n, 5), '
'but found {}'.format(proposals.shape)) 'but found {}'.format(proposals.shape))
if proposals.shape[1] == 5: if proposals.shape[1] == 5:
scores = proposals[:, 4] scores = proposals[:, 4, None]
proposals = proposals[:, :4] proposals = proposals[:, :4]
else: else:
scores = None scores = None
...@@ -237,8 +237,8 @@ class CocoDataset(Dataset): ...@@ -237,8 +237,8 @@ class CocoDataset(Dataset):
if self.proposals is not None: if self.proposals is not None:
proposals = self.bbox_transform(proposals, img_shape, proposals = self.bbox_transform(proposals, img_shape,
scale_factor, flip) scale_factor, flip)
proposals = np.hstack([proposals, scores[:, None] proposals = np.hstack(
]) if scores is not None else proposals [proposals, scores]) if scores is not None else proposals
gt_bboxes = self.bbox_transform(gt_bboxes, img_shape, scale_factor, gt_bboxes = self.bbox_transform(gt_bboxes, img_shape, scale_factor,
flip) flip)
gt_bboxes_ignore = self.bbox_transform(gt_bboxes_ignore, img_shape, gt_bboxes_ignore = self.bbox_transform(gt_bboxes_ignore, img_shape,
...@@ -295,14 +295,14 @@ class CocoDataset(Dataset): ...@@ -295,14 +295,14 @@ class CocoDataset(Dataset):
flip=flip) flip=flip)
if proposal is not None: if proposal is not None:
if proposal.shape[1] == 5: if proposal.shape[1] == 5:
score = proposal[:, 4] score = proposal[:, 4, None]
proposal = proposal[:, :4] proposal = proposal[:, :4]
else: else:
score = None score = None
_proposal = self.bbox_transform(proposal, img_shape, _proposal = self.bbox_transform(proposal, img_shape,
scale_factor, flip) scale_factor, flip)
_proposal = np.hstack([_proposal, score[:, None] _proposal = np.hstack(
]) if score is not None else _proposal [_proposal, score]) if score is not None else _proposal
_proposal = to_tensor(_proposal) _proposal = to_tensor(_proposal)
else: else:
_proposal = None _proposal = None
......
...@@ -59,16 +59,20 @@ class BBoxHead(nn.Module): ...@@ -59,16 +59,20 @@ class BBoxHead(nn.Module):
bbox_pred = self.fc_reg(x) if self.with_reg else None bbox_pred = self.fc_reg(x) if self.with_reg else None
return cls_score, bbox_pred return cls_score, bbox_pred
def get_bbox_target(self, pos_proposals, neg_proposals, pos_gt_bboxes, def get_target(self, sampling_results, gt_bboxes, gt_labels,
pos_gt_labels, rcnn_train_cfg): rcnn_train_cfg):
reg_num_classes = 1 if self.reg_class_agnostic else self.num_classes pos_proposals = [res.pos_bboxes for res in sampling_results]
neg_proposals = [res.neg_bboxes for res in sampling_results]
pos_gt_bboxes = [res.pos_gt_bboxes for res in sampling_results]
pos_gt_labels = [res.pos_gt_labels for res in sampling_results]
reg_classes = 1 if self.reg_class_agnostic else self.num_classes
cls_reg_targets = bbox_target( cls_reg_targets = bbox_target(
pos_proposals, pos_proposals,
neg_proposals, neg_proposals,
pos_gt_bboxes, pos_gt_bboxes,
pos_gt_labels, pos_gt_labels,
rcnn_train_cfg, rcnn_train_cfg,
reg_num_classes, reg_classes,
target_means=self.target_means, target_means=self.target_means,
target_stds=self.target_stds) target_stds=self.target_stds)
return cls_reg_targets return cls_reg_targets
......
...@@ -4,7 +4,7 @@ import torch.nn as nn ...@@ -4,7 +4,7 @@ import torch.nn as nn
from .base import BaseDetector from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder from .. import builder
from mmdet.core import sample_bboxes, bbox2roi, bbox2result, multi_apply from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply)
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...@@ -80,10 +80,11 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -80,10 +80,11 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
gt_labels, gt_labels,
gt_masks=None, gt_masks=None,
proposals=None): proposals=None):
losses = dict()
x = self.extract_feat(img) x = self.extract_feat(img)
losses = dict()
# RPN forward and loss
if self.with_rpn: if self.with_rpn:
rpn_outs = self.rpn_head(x) rpn_outs = self.rpn_head(x)
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
...@@ -96,44 +97,43 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -96,44 +97,43 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
else: else:
proposal_list = proposals proposal_list = proposals
# assign gts and sample proposals
if self.with_bbox or self.with_mask:
assign_results, sampling_results = multi_apply(
assign_and_sample,
proposal_list,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
cfg=self.train_cfg.rcnn)
# bbox head forward and loss
if self.with_bbox: if self.with_bbox:
(pos_proposals, neg_proposals, pos_assigned_gt_inds, pos_gt_bboxes, rois = bbox2roi([res.bboxes for res in sampling_results])
pos_gt_labels) = multi_apply( # TODO: a more flexible way to decide which feature maps to use
sample_bboxes, bbox_feats = self.bbox_roi_extractor(
proposal_list,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
cfg=self.train_cfg.rcnn)
(labels, label_weights, bbox_targets,
bbox_weights) = self.bbox_head.get_bbox_target(
pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels,
self.train_cfg.rcnn)
rois = bbox2roi([
torch.cat([pos, neg], dim=0)
for pos, neg in zip(pos_proposals, neg_proposals)
])
# TODO: a more flexible way to configurate feat maps
roi_feats = self.bbox_roi_extractor(
x[:self.bbox_roi_extractor.num_inputs], rois) x[:self.bbox_roi_extractor.num_inputs], rois)
cls_score, bbox_pred = self.bbox_head(roi_feats) cls_score, bbox_pred = self.bbox_head(bbox_feats)
loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, labels, bbox_targets = self.bbox_head.get_target(
label_weights, bbox_targets, sampling_results, gt_bboxes, gt_labels, self.train_cfg.rcnn)
bbox_weights) loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
*bbox_targets)
losses.update(loss_bbox) losses.update(loss_bbox)
# mask head forward and loss
if self.with_mask: if self.with_mask:
mask_targets = self.mask_head.get_mask_target( pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
pos_proposals, pos_assigned_gt_inds, gt_masks,
self.train_cfg.rcnn)
pos_rois = bbox2roi(pos_proposals)
mask_feats = self.mask_roi_extractor( mask_feats = self.mask_roi_extractor(
x[:self.mask_roi_extractor.num_inputs], pos_rois) x[:self.mask_roi_extractor.num_inputs], pos_rois)
mask_pred = self.mask_head(mask_feats) mask_pred = self.mask_head(mask_feats)
mask_targets = self.mask_head.get_target(
sampling_results, gt_masks, self.train_cfg.rcnn)
pos_labels = torch.cat(
[res.pos_gt_labels for res in sampling_results])
loss_mask = self.mask_head.loss(mask_pred, mask_targets, loss_mask = self.mask_head.loss(mask_pred, mask_targets,
torch.cat(pos_gt_labels)) pos_labels)
losses.update(loss_mask) losses.update(loss_mask)
return losses return losses
...@@ -145,8 +145,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -145,8 +145,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
x = self.extract_feat(img) x = self.extract_feat(img)
proposal_list = self.simple_test_rpn( proposal_list = self.simple_test_rpn(
x, img_meta, x, img_meta, self.test_cfg.rpn) if proposals is None else proposals
self.test_cfg.rpn) if proposals is None else proposals
det_bboxes, det_labels = self.simple_test_bboxes( det_bboxes, det_labels = self.simple_test_bboxes(
x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale) x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
......
...@@ -86,8 +86,11 @@ class FCNMaskHead(nn.Module): ...@@ -86,8 +86,11 @@ class FCNMaskHead(nn.Module):
mask_pred = self.conv_logits(x) mask_pred = self.conv_logits(x)
return mask_pred return mask_pred
def get_mask_target(self, pos_proposals, pos_assigned_gt_inds, gt_masks, def get_target(self, sampling_results, gt_masks, rcnn_train_cfg):
rcnn_train_cfg): pos_proposals = [res.pos_bboxes for res in sampling_results]
pos_assigned_gt_inds = [
res.pos_assigned_gt_inds for res in sampling_results
]
mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds, mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
gt_masks, rcnn_train_cfg) gt_masks, rcnn_train_cfg)
return mask_targets return mask_targets
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment