From bac11303e2776b57b730d4ffb5d41bc5ce34b079 Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Wed, 17 Oct 2018 22:33:17 +0800 Subject: [PATCH] add BBoxAssigner and BBoxSampler --- configs/fast_mask_rcnn_r50_fpn_1x.py | 22 +- configs/fast_rcnn_r50_fpn_1x.py | 22 +- configs/faster_rcnn_r50_fpn_1x.py | 43 +- configs/mask_rcnn_r50_fpn_1x.py | 43 +- configs/rpn_r50_fpn_1x.py | 21 +- mmdet/core/anchor/anchor_target.py | 27 +- mmdet/core/bbox/__init__.py | 15 +- mmdet/core/bbox/assignment.py | 155 +++++++ mmdet/core/bbox/bbox_target.py | 50 +-- mmdet/core/bbox/sampling.py | 496 +++++++++-------------- mmdet/datasets/coco.py | 12 +- mmdet/models/bbox_heads/bbox_head.py | 12 +- mmdet/models/detectors/two_stage.py | 65 ++- mmdet/models/mask_heads/fcn_mask_head.py | 7 +- 14 files changed, 522 insertions(+), 468 deletions(-) create mode 100644 mmdet/core/bbox/assignment.py diff --git a/configs/fast_mask_rcnn_r50_fpn_1x.py b/configs/fast_mask_rcnn_r50_fpn_1x.py index af2070f..22215da 100644 --- a/configs/fast_mask_rcnn_r50_fpn_1x.py +++ b/configs/fast_mask_rcnn_r50_fpn_1x.py @@ -43,17 +43,19 @@ model = dict( # model training and testing settings train_cfg = dict( rcnn=dict( + assigner=dict( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + ignore_iof_thr=-1), + sampler=dict( + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True, + pos_balance_sampling=False, + neg_balance_thr=0), mask_size=28, - pos_iou_thr=0.5, - neg_iou_thr=0.5, - crowd_thr=1.1, - roi_batch_size=512, - add_gt_as_proposals=True, - pos_fraction=0.25, - pos_balance_sampling=False, - neg_pos_ub=512, - neg_balance_thr=0, - min_pos_iou=0.5, pos_weight=-1, debug=False)) test_cfg = dict( diff --git a/configs/fast_rcnn_r50_fpn_1x.py b/configs/fast_rcnn_r50_fpn_1x.py index 397ab43..27de2ff 100644 --- a/configs/fast_rcnn_r50_fpn_1x.py +++ b/configs/fast_rcnn_r50_fpn_1x.py @@ -32,16 +32,18 @@ model = dict( # model training and testing settings train_cfg = dict( 
rcnn=dict( - pos_iou_thr=0.5, - neg_iou_thr=0.5, - crowd_thr=1.1, - roi_batch_size=512, - add_gt_as_proposals=True, - pos_fraction=0.25, - pos_balance_sampling=False, - neg_pos_ub=512, - neg_balance_thr=0, - min_pos_iou=0.5, + assigner=dict( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + ignore_iof_thr=-1), + sampler=dict( + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True, + pos_balance_sampling=False, + neg_balance_thr=0), pos_weight=-1, debug=False)) test_cfg = dict(rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5)) diff --git a/configs/faster_rcnn_r50_fpn_1x.py b/configs/faster_rcnn_r50_fpn_1x.py index 1c06c4c..4ab8b5a 100644 --- a/configs/faster_rcnn_r50_fpn_1x.py +++ b/configs/faster_rcnn_r50_fpn_1x.py @@ -42,30 +42,35 @@ model = dict( # model training and testing settings train_cfg = dict( rpn=dict( - pos_fraction=0.5, - pos_balance_sampling=False, - neg_pos_ub=256, + assigner=dict( + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + ignore_iof_thr=-1), + sampler=dict( + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False, + pos_balance_sampling=False, + neg_balance_thr=0), allowed_border=0, - crowd_thr=1.1, - anchor_batch_size=256, - pos_iou_thr=0.7, - neg_iou_thr=0.3, - neg_balance_thr=0, - min_pos_iou=0.3, pos_weight=-1, smoothl1_beta=1 / 9.0, debug=False), rcnn=dict( - pos_iou_thr=0.5, - neg_iou_thr=0.5, - crowd_thr=1.1, - roi_batch_size=512, - add_gt_as_proposals=True, - pos_fraction=0.25, - pos_balance_sampling=False, - neg_pos_ub=512, - neg_balance_thr=0, - min_pos_iou=0.5, + assigner=dict( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + ignore_iof_thr=-1), + sampler=dict( + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True, + pos_balance_sampling=False, + neg_balance_thr=0), pos_weight=-1, debug=False)) test_cfg = dict( diff --git a/configs/mask_rcnn_r50_fpn_1x.py b/configs/mask_rcnn_r50_fpn_1x.py index 8868cf6..b190ad8 100644 --- 
a/configs/mask_rcnn_r50_fpn_1x.py +++ b/configs/mask_rcnn_r50_fpn_1x.py @@ -53,31 +53,36 @@ model = dict( # model training and testing settings train_cfg = dict( rpn=dict( - pos_fraction=0.5, - pos_balance_sampling=False, - neg_pos_ub=256, + assigner=dict( + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + ignore_iof_thr=-1), + sampler=dict( + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False, + pos_balance_sampling=False, + neg_balance_thr=0), allowed_border=0, - crowd_thr=1.1, - anchor_batch_size=256, - pos_iou_thr=0.7, - neg_iou_thr=0.3, - neg_balance_thr=0, - min_pos_iou=0.3, pos_weight=-1, smoothl1_beta=1 / 9.0, debug=False), rcnn=dict( + assigner=dict( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + ignore_iof_thr=-1), + sampler=dict( + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True, + pos_balance_sampling=False, + neg_balance_thr=0), mask_size=28, - pos_iou_thr=0.5, - neg_iou_thr=0.5, - crowd_thr=1.1, - roi_batch_size=512, - add_gt_as_proposals=True, - pos_fraction=0.25, - pos_balance_sampling=False, - neg_pos_ub=512, - neg_balance_thr=0, - min_pos_iou=0.5, pos_weight=-1, debug=False)) test_cfg = dict( diff --git a/configs/rpn_r50_fpn_1x.py b/configs/rpn_r50_fpn_1x.py index 7f1b6d0..8e2b402 100644 --- a/configs/rpn_r50_fpn_1x.py +++ b/configs/rpn_r50_fpn_1x.py @@ -27,16 +27,19 @@ model = dict( # model training and testing settings train_cfg = dict( rpn=dict( - pos_fraction=0.5, - pos_balance_sampling=False, - neg_pos_ub=256, + assigner=dict( + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + ignore_iof_thr=-1), + sampler=dict( + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False, + pos_balance_sampling=False, + neg_balance_thr=0), allowed_border=0, - crowd_thr=1.1, - anchor_batch_size=256, - pos_iou_thr=0.7, - neg_iou_thr=0.3, - neg_balance_thr=0, - min_pos_iou=0.3, pos_weight=-1, smoothl1_beta=1 / 9.0, debug=False)) diff --git 
a/mmdet/core/anchor/anchor_target.py b/mmdet/core/anchor/anchor_target.py index ad81e39..49047b9 100644 --- a/mmdet/core/anchor/anchor_target.py +++ b/mmdet/core/anchor/anchor_target.py @@ -1,6 +1,6 @@ import torch -from ..bbox import bbox_assign, bbox2delta, bbox_sampling +from ..bbox import assign_and_sample, bbox2delta from ..utils import multi_apply @@ -80,27 +80,20 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta, return (None, ) * 6 # assign gt and sample anchors anchors = flat_anchors[inside_flags, :] - assigned_gt_inds, argmax_overlaps, max_overlaps = bbox_assign( - anchors, - gt_bboxes, - pos_iou_thr=cfg.pos_iou_thr, - neg_iou_thr=cfg.neg_iou_thr, - min_pos_iou=cfg.min_pos_iou) - pos_inds, neg_inds = bbox_sampling(assigned_gt_inds, cfg.anchor_batch_size, - cfg.pos_fraction, cfg.neg_pos_ub, - cfg.pos_balance_sampling, max_overlaps, - cfg.neg_balance_thr) + _, sampling_result = assign_and_sample(anchors, gt_bboxes, None, None, cfg) + num_valid_anchors = anchors.shape[0] bbox_targets = torch.zeros_like(anchors) bbox_weights = torch.zeros_like(anchors) - labels = torch.zeros_like(assigned_gt_inds) - label_weights = torch.zeros_like(assigned_gt_inds, dtype=anchors.dtype) + labels = anchors.new_zeros((num_valid_anchors, )) + label_weights = anchors.new_zeros((num_valid_anchors, )) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds if len(pos_inds) > 0: - pos_anchors = anchors[pos_inds, :] - pos_gt_bbox = gt_bboxes[assigned_gt_inds[pos_inds] - 1, :] - pos_bbox_targets = bbox2delta(pos_anchors, pos_gt_bbox, target_means, - target_stds) + pos_bbox_targets = bbox2delta(sampling_result.pos_bboxes, + sampling_result.pos_gt_bboxes, + target_means, target_stds) bbox_targets[pos_inds, :] = pos_bbox_targets bbox_weights[pos_inds, :] = 1.0 labels[pos_inds] = 1 diff --git a/mmdet/core/bbox/__init__.py b/mmdet/core/bbox/__init__.py index a5c21dc..2ed869f 100644 --- a/mmdet/core/bbox/__init__.py +++ 
b/mmdet/core/bbox/__init__.py @@ -1,15 +1,14 @@ from .geometry import bbox_overlaps -from .sampling import (random_choice, bbox_assign, bbox_assign_wrt_overlaps, - bbox_sampling, bbox_sampling_pos, bbox_sampling_neg, - sample_bboxes) +from .assignment import BBoxAssigner, AssignResult +from .sampling import (BBoxSampler, SamplingResult, assign_and_sample, + random_choice) from .transforms import (bbox2delta, delta2bbox, bbox_flip, bbox_mapping, bbox_mapping_back, bbox2roi, roi2bbox, bbox2result) from .bbox_target import bbox_target __all__ = [ - 'bbox_overlaps', 'random_choice', 'bbox_assign', - 'bbox_assign_wrt_overlaps', 'bbox_sampling', 'bbox_sampling_pos', - 'bbox_sampling_neg', 'sample_bboxes', 'bbox2delta', 'delta2bbox', - 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', - 'bbox2result', 'bbox_target' + 'bbox_overlaps', 'BBoxAssigner', 'AssignResult', 'BBoxSampler', + 'SamplingResult', 'assign_and_sample', 'random_choice', 'bbox2delta', + 'delta2bbox', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', + 'roi2bbox', 'bbox2result', 'bbox_target' ] diff --git a/mmdet/core/bbox/assignment.py b/mmdet/core/bbox/assignment.py new file mode 100644 index 0000000..62233af --- /dev/null +++ b/mmdet/core/bbox/assignment.py @@ -0,0 +1,155 @@ +import torch + +from .geometry import bbox_overlaps + + +class BBoxAssigner(object): + """Assign a corresponding gt bbox or background to each bbox. + + Each proposal will be assigned with `-1`, `0`, or a positive integer + indicating the ground truth index. + + - -1: don't care + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + + Args: + pos_iou_thr (float): IoU threshold for positive bboxes. + neg_iou_thr (float or tuple): IoU threshold for negative bboxes. + min_pos_iou (float): Minimum iou for a bbox to be considered as a + positive bbox. 
For RPN, it is usually set as 0.3, for Fast R-CNN, + it is usually set as pos_iou_thr + ignore_iof_thr (float): IoF threshold for ignoring bboxes (if + `gt_bboxes_ignore` is specified). Negative values mean not + ignoring any bboxes. + """ + + def __init__(self, + pos_iou_thr, + neg_iou_thr, + min_pos_iou=.0, + ignore_iof_thr=-1): + self.pos_iou_thr = pos_iou_thr + self.neg_iou_thr = neg_iou_thr + self.min_pos_iou = min_pos_iou + self.ignore_iof_thr = ignore_iof_thr + + def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None): + """Assign gt to bboxes. + + This method assign a gt bbox to every bbox (proposal/anchor), each bbox + will be assigned with -1, 0, or a positive number. -1 means don't care, + 0 means negative sample, positive number is the index (1-based) of + assigned gt. + The assignment is done in following steps, the order matters. + + 1. assign every bbox to -1 + 2. assign proposals whose iou with all gts < neg_iou_thr to 0 + 3. for each bbox, if the iou with its nearest gt >= pos_iou_thr, + assign it to that bbox + 4. for each gt bbox, assign its nearest proposals (may be more than + one) to itself + + Args: + bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4). + gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4). + gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are + labelled as `ignored`, e.g., crowd boxes in COCO. + gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ). + + Returns: + :obj:`AssignResult`: The assign result. 
+ """ + if bboxes.shape[0] == 0 or gt_bboxes.shape[0] == 0: + raise ValueError('No gt or bboxes') + bboxes = bboxes[:, :4] + overlaps = bbox_overlaps(bboxes, gt_bboxes) + + if (self.ignore_iof_thr > 0) and (gt_bboxes_ignore is not None) and ( + gt_bboxes_ignore.numel() > 0): + ignore_overlaps = bbox_overlaps( + bboxes, gt_bboxes_ignore, mode='iof') + ignore_max_overlaps, _ = ignore_overlaps.max(dim=1) + ignore_bboxes_inds = torch.nonzero( + ignore_max_overlaps > self.ignore_iof_thr).squeeze() + if ignore_bboxes_inds.numel() > 0: + overlaps[ignore_bboxes_inds[:, 0], :] = -1 + + assign_result = self.assign_wrt_overlaps(overlaps, gt_labels) + return assign_result + + def assign_wrt_overlaps(self, overlaps, gt_labels=None): + """Assign w.r.t. the overlaps of bboxes with gts. + + Args: + overlaps (Tensor): Overlaps between n bboxes and k gt_bboxes, + shape(n, k). + gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ). + + Returns: + :obj:`AssignResult`: The assign result. + """ + if overlaps.numel() == 0: + raise ValueError('No gt or proposals') + + num_bboxes, num_gts = overlaps.size(0), overlaps.size(1) + + # 1. assign -1 by default + assigned_gt_inds = overlaps.new_full( + (num_bboxes, ), -1, dtype=torch.long) + + assert overlaps.size() == (num_bboxes, num_gts) + # for each anchor, which gt best overlaps with it + # for each anchor, the max iou of all gts + max_overlaps, argmax_overlaps = overlaps.max(dim=1) + # for each gt, which anchor best overlaps with it + # for each gt, the max iou of all proposals + gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=0) + + # 2. assign negative: below + if isinstance(self.neg_iou_thr, float): + assigned_gt_inds[(max_overlaps >= 0) + & (max_overlaps < self.neg_iou_thr)] = 0 + elif isinstance(self.neg_iou_thr, tuple): + assert len(self.neg_iou_thr) == 2 + assigned_gt_inds[(max_overlaps >= self.neg_iou_thr[0]) + & (max_overlaps < self.neg_iou_thr[1])] = 0 + + # 3. 
assign positive: above positive IoU threshold + pos_inds = max_overlaps >= self.pos_iou_thr + assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1 + + # 4. assign fg: for each gt, proposals with highest IoU + for i in range(num_gts): + if gt_max_overlaps[i] >= self.min_pos_iou: + assigned_gt_inds[overlaps[:, i] == gt_max_overlaps[i]] = i + 1 + + if gt_labels is not None: + assigned_labels = assigned_gt_inds.new_zeros((num_bboxes, )) + pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze() + if pos_inds.numel() > 0: + assigned_labels[pos_inds] = gt_labels[ + assigned_gt_inds[pos_inds] - 1] + else: + assigned_labels = None + + return AssignResult( + num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels) + + +class AssignResult(object): + + def __init__(self, num_gts, gt_inds, max_overlaps, labels=None): + self.num_gts = num_gts + self.gt_inds = gt_inds + self.max_overlaps = max_overlaps + self.labels = labels + + def add_gt_(self, gt_labels): + self_inds = torch.arange( + 1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device) + self.gt_inds = torch.cat([self_inds, self.gt_inds]) + self.max_overlaps = torch.cat( + [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps]) + if self.labels is not None: + self.labels = torch.cat([gt_labels, self.labels]) diff --git a/mmdet/core/bbox/bbox_target.py b/mmdet/core/bbox/bbox_target.py index 2e205c3..4a0450d 100644 --- a/mmdet/core/bbox/bbox_target.py +++ b/mmdet/core/bbox/bbox_target.py @@ -4,23 +4,23 @@ from .transforms import bbox2delta from ..utils import multi_apply -def bbox_target(pos_proposals_list, - neg_proposals_list, +def bbox_target(pos_bboxes_list, + neg_bboxes_list, pos_gt_bboxes_list, pos_gt_labels_list, cfg, - reg_num_classes=1, + reg_classes=1, target_means=[.0, .0, .0, .0], target_stds=[1.0, 1.0, 1.0, 1.0], concat=True): labels, label_weights, bbox_targets, bbox_weights = multi_apply( - proposal_target_single, - pos_proposals_list, - neg_proposals_list, + bbox_target_single, 
+ pos_bboxes_list, + neg_bboxes_list, pos_gt_bboxes_list, pos_gt_labels_list, cfg=cfg, - reg_num_classes=reg_num_classes, + reg_classes=reg_classes, target_means=target_means, target_stds=target_stds) @@ -32,34 +32,34 @@ def bbox_target(pos_proposals_list, return labels, label_weights, bbox_targets, bbox_weights -def proposal_target_single(pos_proposals, - neg_proposals, - pos_gt_bboxes, - pos_gt_labels, - cfg, - reg_num_classes=1, - target_means=[.0, .0, .0, .0], - target_stds=[1.0, 1.0, 1.0, 1.0]): - num_pos = pos_proposals.size(0) - num_neg = neg_proposals.size(0) +def bbox_target_single(pos_bboxes, + neg_bboxes, + pos_gt_bboxes, + pos_gt_labels, + cfg, + reg_classes=1, + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]): + num_pos = pos_bboxes.size(0) + num_neg = neg_bboxes.size(0) num_samples = num_pos + num_neg - labels = pos_proposals.new_zeros(num_samples, dtype=torch.long) - label_weights = pos_proposals.new_zeros(num_samples) - bbox_targets = pos_proposals.new_zeros(num_samples, 4) - bbox_weights = pos_proposals.new_zeros(num_samples, 4) + labels = pos_bboxes.new_zeros(num_samples, dtype=torch.long) + label_weights = pos_bboxes.new_zeros(num_samples) + bbox_targets = pos_bboxes.new_zeros(num_samples, 4) + bbox_weights = pos_bboxes.new_zeros(num_samples, 4) if num_pos > 0: labels[:num_pos] = pos_gt_labels pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight label_weights[:num_pos] = pos_weight - pos_bbox_targets = bbox2delta(pos_proposals, pos_gt_bboxes, - target_means, target_stds) + pos_bbox_targets = bbox2delta(pos_bboxes, pos_gt_bboxes, target_means, + target_stds) bbox_targets[:num_pos, :] = pos_bbox_targets bbox_weights[:num_pos, :] = 1 if num_neg > 0: label_weights[-num_neg:] = 1.0 - if reg_num_classes > 1: + if reg_classes > 1: bbox_targets, bbox_weights = expand_target(bbox_targets, bbox_weights, - labels, reg_num_classes) + labels, reg_classes) return labels, label_weights, bbox_targets, bbox_weights diff --git 
a/mmdet/core/bbox/sampling.py b/mmdet/core/bbox/sampling.py index 976cd95..63d2279 100644 --- a/mmdet/core/bbox/sampling.py +++ b/mmdet/core/bbox/sampling.py @@ -1,7 +1,7 @@ import numpy as np import torch -from .geometry import bbox_overlaps +from .assignment import BBoxAssigner def random_choice(gallery, num): @@ -21,323 +21,207 @@ def random_choice(gallery, num): return gallery[rand_inds] -def bbox_assign(proposals, - gt_bboxes, - gt_bboxes_ignore=None, - gt_labels=None, - pos_iou_thr=0.5, - neg_iou_thr=0.5, - min_pos_iou=.0, - crowd_thr=-1): - """Assign a corresponding gt bbox or background to each proposal/anchor. +def assign_and_sample(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg): + bbox_assigner = BBoxAssigner(**cfg.assigner) + bbox_sampler = BBoxSampler(**cfg.sampler) + assign_result = bbox_assigner.assign(bboxes, gt_bboxes, gt_bboxes_ignore, + gt_labels) + sampling_result = bbox_sampler.sample(assign_result, bboxes, gt_bboxes, + gt_labels) + return assign_result, sampling_result - Each proposals will be assigned with `-1`, `0`, or a positive integer. - - -1: don't care - - 0: negative sample, no assigned gt - - positive integer: positive sample, index (1-based) of assigned gt - - If `gt_bboxes_ignore` is specified, bboxes which have iof (intersection - over foreground) with `gt_bboxes_ignore` above `crowd_thr` will be ignored. - - Args: - proposals (Tensor): Proposals or RPN anchors, shape (n, 4). - gt_bboxes (Tensor): Ground truth bboxes, shape (k, 4). - gt_bboxes_ignore (Tensor, optional): shape(m, 4). - gt_labels (Tensor, optional): shape (k, ). - pos_iou_thr (float): IoU threshold for positive bboxes. - neg_iou_thr (float or tuple): IoU threshold for negative bboxes. - min_pos_iou (float): Minimum iou for a bbox to be considered as a - positive bbox. For RPN, it is usually set as 0.3, for Fast R-CNN, - it is usually set as pos_iou_thr - crowd_thr (float): IoF threshold for ignoring bboxes. Negative value - for not ignoring any bboxes. 
- - Returns: - tuple: (assigned_gt_inds, argmax_overlaps, max_overlaps), shape (n, ) - """ - - # calculate overlaps between the proposals and the gt boxes - overlaps = bbox_overlaps(proposals, gt_bboxes) - if overlaps.numel() == 0: - raise ValueError('No gt bbox or proposals') - - # ignore proposals according to crowd bboxes - if (crowd_thr > 0) and (gt_bboxes_ignore is - not None) and (gt_bboxes_ignore.numel() > 0): - crowd_overlaps = bbox_overlaps(proposals, gt_bboxes_ignore, mode='iof') - crowd_max_overlaps, _ = crowd_overlaps.max(dim=1) - crowd_bboxes_inds = torch.nonzero( - crowd_max_overlaps > crowd_thr).long() - if crowd_bboxes_inds.numel() > 0: - overlaps[crowd_bboxes_inds, :] = -1 - - return bbox_assign_wrt_overlaps(overlaps, gt_labels, pos_iou_thr, - neg_iou_thr, min_pos_iou) - - -def bbox_assign_wrt_overlaps(overlaps, - gt_labels=None, - pos_iou_thr=0.5, - neg_iou_thr=0.5, - min_pos_iou=.0): - """Assign a corresponding gt bbox or background to each proposal/anchor. - - This method assign a gt bbox to every proposal, each proposals will be - assigned with -1, 0, or a positive number. -1 means don't care, 0 means - negative sample, positive number is the index (1-based) of assigned gt. - The assignment is done in following steps, the order matters: - - 1. assign every anchor to -1 - 2. assign proposals whose iou with all gts < neg_iou_thr to 0 - 3. for each anchor, if the iou with its nearest gt >= pos_iou_thr, - assign it to that bbox - 4. for each gt bbox, assign its nearest proposals(may be more than one) - to itself - - Args: - overlaps (Tensor): Overlaps between n proposals and k gt_bboxes, - shape(n, k). - gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ). - pos_iou_thr (float): IoU threshold for positive bboxes. - neg_iou_thr (float or tuple): IoU threshold for negative bboxes. - min_pos_iou (float): Minimum IoU for a bbox to be considered as a - positive bbox. This argument only affects the 4th step. 
- - Returns: - tuple: (assigned_gt_inds, [assigned_labels], argmax_overlaps, - max_overlaps), shape (n, ) - """ - num_bboxes, num_gts = overlaps.size(0), overlaps.size(1) - # 1. assign -1 by default - assigned_gt_inds = overlaps.new(num_bboxes).long().fill_(-1) - - if overlaps.numel() == 0: - raise ValueError('No gt bbox or proposals') - - assert overlaps.size() == (num_bboxes, num_gts) - # for each anchor, which gt best overlaps with it - # for each anchor, the max iou of all gts - max_overlaps, argmax_overlaps = overlaps.max(dim=1) - # for each gt, which anchor best overlaps with it - # for each gt, the max iou of all proposals - gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=0) - - # 2. assign negative: below - if isinstance(neg_iou_thr, float): - assigned_gt_inds[(max_overlaps >= 0) - & (max_overlaps < neg_iou_thr)] = 0 - elif isinstance(neg_iou_thr, tuple): - assert len(neg_iou_thr) == 2 - assigned_gt_inds[(max_overlaps >= neg_iou_thr[0]) - & (max_overlaps < neg_iou_thr[1])] = 0 - - # 3. assign positive: above positive IoU threshold - pos_inds = max_overlaps >= pos_iou_thr - assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1 - - # 4. assign fg: for each gt, proposals with highest IoU - for i in range(num_gts): - if gt_max_overlaps[i] >= min_pos_iou: - assigned_gt_inds[overlaps[:, i] == gt_max_overlaps[i]] = i + 1 - - if gt_labels is None: - return assigned_gt_inds, argmax_overlaps, max_overlaps - else: - assigned_labels = assigned_gt_inds.new(num_bboxes).fill_(0) - pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze() - if pos_inds.numel() > 0: - assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] - - 1] - return assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps - - -def bbox_sampling_pos(assigned_gt_inds, num_expected, balance_sampling=True): - """Balance sampling for positive bboxes/anchors. - - 1. calculate average positive num for each gt: num_per_gt - 2. sample at most num_per_gt positives for each gt - 3. 
random sampling from rest anchors if not enough fg - """ - pos_inds = torch.nonzero(assigned_gt_inds > 0) - if pos_inds.numel() != 0: - pos_inds = pos_inds.squeeze(1) - if pos_inds.numel() <= num_expected: - return pos_inds - elif not balance_sampling: - return random_choice(pos_inds, num_expected) - else: - unique_gt_inds = torch.unique(assigned_gt_inds[pos_inds].cpu()) - num_gts = len(unique_gt_inds) - num_per_gt = int(round(num_expected / float(num_gts)) + 1) - sampled_inds = [] - for i in unique_gt_inds: - inds = torch.nonzero(assigned_gt_inds == i.item()) - if inds.numel() != 0: - inds = inds.squeeze(1) - else: - continue - if len(inds) > num_per_gt: - inds = random_choice(inds, num_per_gt) - sampled_inds.append(inds) - sampled_inds = torch.cat(sampled_inds) - if len(sampled_inds) < num_expected: - num_extra = num_expected - len(sampled_inds) - extra_inds = np.array( - list(set(pos_inds.cpu()) - set(sampled_inds.cpu()))) - if len(extra_inds) > num_extra: - extra_inds = random_choice(extra_inds, num_extra) - extra_inds = torch.from_numpy(extra_inds).to( - assigned_gt_inds.device).long() - sampled_inds = torch.cat([sampled_inds, extra_inds]) - elif len(sampled_inds) > num_expected: - sampled_inds = random_choice(sampled_inds, num_expected) - return sampled_inds - - -def bbox_sampling_neg(assigned_gt_inds, - num_expected, - max_overlaps=None, - balance_thr=0, - hard_fraction=0.5): - """Balance sampling for negative bboxes/anchors. - - Negative samples are split into 2 set: hard (balance_thr <= iou < - neg_iou_thr) and easy(iou < balance_thr). The sampling ratio is controlled - by `hard_fraction`. 
- """ - neg_inds = torch.nonzero(assigned_gt_inds == 0) - if neg_inds.numel() != 0: - neg_inds = neg_inds.squeeze(1) - if len(neg_inds) <= num_expected: - return neg_inds - elif balance_thr <= 0: - # uniform sampling among all negative samples - return random_choice(neg_inds, num_expected) - else: - assert max_overlaps is not None - max_overlaps = max_overlaps.cpu().numpy() - # balance sampling for negative samples - neg_set = set(neg_inds.cpu().numpy()) - easy_set = set( - np.where( - np.logical_and(max_overlaps >= 0, - max_overlaps < balance_thr))[0]) - hard_set = set(np.where(max_overlaps >= balance_thr)[0]) - easy_neg_inds = list(easy_set & neg_set) - hard_neg_inds = list(hard_set & neg_set) - - num_expected_hard = int(num_expected * hard_fraction) - if len(hard_neg_inds) > num_expected_hard: - sampled_hard_inds = random_choice(hard_neg_inds, num_expected_hard) - else: - sampled_hard_inds = np.array(hard_neg_inds, dtype=np.int) - num_expected_easy = num_expected - len(sampled_hard_inds) - if len(easy_neg_inds) > num_expected_easy: - sampled_easy_inds = random_choice(easy_neg_inds, num_expected_easy) - else: - sampled_easy_inds = np.array(easy_neg_inds, dtype=np.int) - sampled_inds = np.concatenate((sampled_easy_inds, sampled_hard_inds)) - if len(sampled_inds) < num_expected: - num_extra = num_expected - len(sampled_inds) - extra_inds = np.array(list(neg_set - set(sampled_inds))) - if len(extra_inds) > num_extra: - extra_inds = random_choice(extra_inds, num_extra) - sampled_inds = np.concatenate((sampled_inds, extra_inds)) - sampled_inds = torch.from_numpy(sampled_inds).long().to( - assigned_gt_inds.device) - return sampled_inds - - -def bbox_sampling(assigned_gt_inds, - num_expected, - pos_fraction, - neg_pos_ub, - pos_balance_sampling=True, - max_overlaps=None, - neg_balance_thr=0, - neg_hard_fraction=0.5): +class BBoxSampler(object): """Sample positive and negative bboxes given assigned results. 
Args: - assigned_gt_inds (Tensor): Assigned gt indices for each bbox. - num_expected (int): Expected total samples (pos and neg). pos_fraction (float): Positive sample fraction. neg_pos_ub (float): Negative/Positive upper bound. - pos_balance_sampling(bool): Whether to sample positive samples around + pos_balance_sampling (bool): Whether to sample positive samples around each gt bbox evenly. - max_overlaps (Tensor, optional): For each bbox, the max IoU of all gts. - Used for negative balance sampling only. neg_balance_thr (float, optional): IoU threshold for simple/hard negative balance sampling. neg_hard_fraction (float, optional): Fraction of hard negative samples for negative balance sampling. - - Returns: - tuple[Tensor]: positive bbox indices, negative bbox indices. - """ - num_expected_pos = int(num_expected * pos_fraction) - pos_inds = bbox_sampling_pos(assigned_gt_inds, num_expected_pos, - pos_balance_sampling) - # We found that sampled indices have duplicated items occasionally. - # (mab be a bug of PyTorch) - pos_inds = pos_inds.unique() - num_sampled_pos = pos_inds.numel() - num_neg_max = int( - neg_pos_ub * - num_sampled_pos) if num_sampled_pos > 0 else int(neg_pos_ub) - num_expected_neg = min(num_neg_max, num_expected - num_sampled_pos) - neg_inds = bbox_sampling_neg(assigned_gt_inds, num_expected_neg, - max_overlaps, neg_balance_thr, - neg_hard_fraction) - neg_inds = neg_inds.unique() - return pos_inds, neg_inds - - -def sample_bboxes(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg): - """Sample positive and negative bboxes. - - This is a simple implementation of bbox sampling given candidates and - ground truth bboxes, which includes 3 steps. - - 1. Assign gt to each bbox. - 2. Add gt bboxes to the sampling pool (optional). - 3. Perform positive and negative sampling. - - Args: - bboxes (Tensor): Boxes to be sampled from. - gt_bboxes (Tensor): Ground truth bboxes. - gt_bboxes_ignore (Tensor): Ignored ground truth bboxes. 
In MS COCO, - `crowd` bboxes are considered as ignored. - gt_labels (Tensor): Class labels of ground truth bboxes. - cfg (dict): Sampling configs. - - Returns: - tuple[Tensor]: pos_bboxes, neg_bboxes, pos_assigned_gt_inds, - pos_gt_bboxes, pos_gt_labels """ - bboxes = bboxes[:, :4] - assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps = \ - bbox_assign(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, - cfg.pos_iou_thr, cfg.neg_iou_thr, cfg.min_pos_iou, - cfg.crowd_thr) - - if cfg.add_gt_as_proposals: - bboxes = torch.cat([gt_bboxes, bboxes], dim=0) - gt_assign_self = torch.arange( - 1, len(gt_labels) + 1, dtype=torch.long, device=bboxes.device) - assigned_gt_inds = torch.cat([gt_assign_self, assigned_gt_inds]) - assigned_labels = torch.cat([gt_labels, assigned_labels]) - pos_inds, neg_inds = bbox_sampling( - assigned_gt_inds, cfg.roi_batch_size, cfg.pos_fraction, cfg.neg_pos_ub, - cfg.pos_balance_sampling, max_overlaps, cfg.neg_balance_thr) - - pos_bboxes = bboxes[pos_inds] - neg_bboxes = bboxes[neg_inds] - pos_assigned_gt_inds = assigned_gt_inds[pos_inds] - 1 - pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :] - pos_gt_labels = assigned_labels[pos_inds] + def __init__(self, + num, + pos_fraction, + neg_pos_ub=-1, + add_gt_as_proposals=True, + pos_balance_sampling=False, + neg_balance_thr=0, + neg_hard_fraction=0.5): + self.num = num + self.pos_fraction = pos_fraction + self.neg_pos_ub = neg_pos_ub + self.add_gt_as_proposals = add_gt_as_proposals + self.pos_balance_sampling = pos_balance_sampling + self.neg_balance_thr = neg_balance_thr + self.neg_hard_fraction = neg_hard_fraction + + def _sample_pos(self, assign_result, num_expected): + """Balance sampling for positive bboxes/anchors. + + 1. calculate average positive num for each gt: num_per_gt + 2. sample at most num_per_gt positives for each gt + 3. 
random sampling from rest anchors if not enough fg + """ + pos_inds = torch.nonzero(assign_result.gt_inds > 0) + if pos_inds.numel() != 0: + pos_inds = pos_inds.squeeze(1) + if pos_inds.numel() <= num_expected: + return pos_inds + elif not self.pos_balance_sampling: + return random_choice(pos_inds, num_expected) + else: + unique_gt_inds = torch.unique( + assign_result.gt_inds[pos_inds].cpu()) + num_gts = len(unique_gt_inds) + num_per_gt = int(round(num_expected / float(num_gts)) + 1) + sampled_inds = [] + for i in unique_gt_inds: + inds = torch.nonzero(assign_result.gt_inds == i.item()) + if inds.numel() != 0: + inds = inds.squeeze(1) + else: + continue + if len(inds) > num_per_gt: + inds = random_choice(inds, num_per_gt) + sampled_inds.append(inds) + sampled_inds = torch.cat(sampled_inds) + if len(sampled_inds) < num_expected: + num_extra = num_expected - len(sampled_inds) + extra_inds = np.setdiff1d( + pos_inds.cpu().numpy(), sampled_inds.cpu().numpy()) + if len(extra_inds) > num_extra: + extra_inds = random_choice(extra_inds, num_extra) + extra_inds = torch.from_numpy(extra_inds).to( + assign_result.gt_inds.device).long() + sampled_inds = torch.cat([sampled_inds, extra_inds]) + elif len(sampled_inds) > num_expected: + sampled_inds = random_choice(sampled_inds, num_expected) + return sampled_inds + + def _sample_neg(self, assign_result, num_expected): + """Balance sampling for negative bboxes/anchors. + + Negative samples are split into 2 sets: hard (neg_balance_thr <= iou < + neg_iou_thr) and easy (iou < neg_balance_thr). The sampling ratio is + controlled by `neg_hard_fraction`.
+ """ + neg_inds = torch.nonzero(assign_result.gt_inds == 0) + if neg_inds.numel() != 0: + neg_inds = neg_inds.squeeze(1) + if len(neg_inds) <= num_expected: + return neg_inds + elif self.neg_balance_thr <= 0: + # uniform sampling among all negative samples + return random_choice(neg_inds, num_expected) + else: + max_overlaps = assign_result.max_overlaps.cpu().numpy() + # balance sampling for negative samples + neg_set = set(neg_inds.cpu().numpy()) + easy_set = set( + np.where( + np.logical_and(max_overlaps >= 0, + max_overlaps < self.neg_balance_thr))[0]) + hard_set = set(np.where(max_overlaps >= self.neg_balance_thr)[0]) + easy_neg_inds = list(easy_set & neg_set) + hard_neg_inds = list(hard_set & neg_set) + + num_expected_hard = int(num_expected * self.neg_hard_fraction) + if len(hard_neg_inds) > num_expected_hard: + sampled_hard_inds = random_choice(hard_neg_inds, + num_expected_hard) + else: + sampled_hard_inds = np.array(hard_neg_inds, dtype=int) + num_expected_easy = num_expected - len(sampled_hard_inds) + if len(easy_neg_inds) > num_expected_easy: + sampled_easy_inds = random_choice(easy_neg_inds, + num_expected_easy) + else: + sampled_easy_inds = np.array(easy_neg_inds, dtype=int) + sampled_inds = np.concatenate((sampled_easy_inds, + sampled_hard_inds)) + if len(sampled_inds) < num_expected: + num_extra = num_expected - len(sampled_inds) + extra_inds = np.array(list(neg_set - set(sampled_inds))) + if len(extra_inds) > num_extra: + extra_inds = random_choice(extra_inds, num_extra) + sampled_inds = np.concatenate((sampled_inds, extra_inds)) + sampled_inds = torch.from_numpy(sampled_inds).long().to( + assign_result.gt_inds.device) + return sampled_inds + + def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None): + """Sample positive and negative bboxes. + + This is a simple implementation of bbox sampling given candidates, + assigning results and ground truth bboxes. + + 1. Assign gt to each bbox. + 2.
Add gt bboxes to the sampling pool (optional). + 3. Perform positive and negative sampling. + + Args: + assign_result (:obj:`AssignResult`): Bbox assigning results. + bboxes (Tensor): Boxes to be sampled from. + gt_bboxes (Tensor): Ground truth bboxes. + gt_labels (Tensor, optional): Class labels of ground truth bboxes. + + Returns: + :obj:`SamplingResult`: Sampling result. + """ + bboxes = bboxes[:, :4] + + gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8) + if self.add_gt_as_proposals: + bboxes = torch.cat([gt_bboxes, bboxes], dim=0) + assign_result.add_gt_(gt_labels) + gt_flags = torch.cat([ + bboxes.new_ones((gt_bboxes.shape[0], ), dtype=torch.uint8), + gt_flags + ]) + + num_expected_pos = int(self.num * self.pos_fraction) + pos_inds = self._sample_pos(assign_result, num_expected_pos) + # We found that sampled indices have duplicated items occasionally. + # (may be a bug of PyTorch) + pos_inds = pos_inds.unique() + num_sampled_pos = pos_inds.numel() + num_expected_neg = self.num - num_sampled_pos + if self.neg_pos_ub >= 0: + num_neg_max = int(self.neg_pos_ub * + num_sampled_pos) if num_sampled_pos > 0 else int( + self.neg_pos_ub) + num_expected_neg = min(num_neg_max, num_expected_neg) + neg_inds = self._sample_neg(assign_result, num_expected_neg) + neg_inds = neg_inds.unique() + + return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, + assign_result, gt_flags) + + +class SamplingResult(object): + + def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, + gt_flags): + self.pos_inds = pos_inds + self.neg_inds = neg_inds + self.pos_bboxes = bboxes[pos_inds] + self.neg_bboxes = bboxes[neg_inds] + self.pos_is_gt = gt_flags[pos_inds] + + self.num_gts = gt_bboxes.shape[0] + self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :] + if assign_result.labels is not None: + self.pos_gt_labels = assign_result.labels[pos_inds] + else: + self.pos_gt_labels = None
- return (pos_bboxes, neg_bboxes, pos_assigned_gt_inds, pos_gt_bboxes, - pos_gt_labels) + @property + def bboxes(self): + return torch.cat([self.pos_bboxes, self.neg_bboxes]) diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py index 0ee92f9..fbb14aa 100644 --- a/mmdet/datasets/coco.py +++ b/mmdet/datasets/coco.py @@ -215,7 +215,7 @@ class CocoDataset(Dataset): 'proposals should have shapes (n, 4) or (n, 5), ' 'but found {}'.format(proposals.shape)) if proposals.shape[1] == 5: - scores = proposals[:, 4] + scores = proposals[:, 4, None] proposals = proposals[:, :4] else: scores = None @@ -237,8 +237,8 @@ class CocoDataset(Dataset): if self.proposals is not None: proposals = self.bbox_transform(proposals, img_shape, scale_factor, flip) - proposals = np.hstack([proposals, scores[:, None] - ]) if scores is not None else proposals + proposals = np.hstack( + [proposals, scores]) if scores is not None else proposals gt_bboxes = self.bbox_transform(gt_bboxes, img_shape, scale_factor, flip) gt_bboxes_ignore = self.bbox_transform(gt_bboxes_ignore, img_shape, @@ -295,14 +295,14 @@ class CocoDataset(Dataset): flip=flip) if proposal is not None: if proposal.shape[1] == 5: - score = proposal[:, 4] + score = proposal[:, 4, None] proposal = proposal[:, :4] else: score = None _proposal = self.bbox_transform(proposal, img_shape, scale_factor, flip) - _proposal = np.hstack([_proposal, score[:, None] - ]) if score is not None else _proposal + _proposal = np.hstack( + [_proposal, score]) if score is not None else _proposal _proposal = to_tensor(_proposal) else: _proposal = None diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py index 67dba03..9b423bd 100644 --- a/mmdet/models/bbox_heads/bbox_head.py +++ b/mmdet/models/bbox_heads/bbox_head.py @@ -59,16 +59,20 @@ class BBoxHead(nn.Module): bbox_pred = self.fc_reg(x) if self.with_reg else None return cls_score, bbox_pred - def get_bbox_target(self, pos_proposals, neg_proposals, pos_gt_bboxes, 
- pos_gt_labels, rcnn_train_cfg): - reg_num_classes = 1 if self.reg_class_agnostic else self.num_classes + def get_target(self, sampling_results, gt_bboxes, gt_labels, + rcnn_train_cfg): + pos_proposals = [res.pos_bboxes for res in sampling_results] + neg_proposals = [res.neg_bboxes for res in sampling_results] + pos_gt_bboxes = [res.pos_gt_bboxes for res in sampling_results] + pos_gt_labels = [res.pos_gt_labels for res in sampling_results] + reg_classes = 1 if self.reg_class_agnostic else self.num_classes cls_reg_targets = bbox_target( pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels, rcnn_train_cfg, - reg_num_classes, + reg_classes, target_means=self.target_means, target_stds=self.target_stds) return cls_reg_targets diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py index 48a818d..064ee0e 100644 --- a/mmdet/models/detectors/two_stage.py +++ b/mmdet/models/detectors/two_stage.py @@ -4,7 +4,7 @@ import torch.nn as nn from .base import BaseDetector from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin from .. 
import builder -from mmdet.core import sample_bboxes, bbox2roi, bbox2result, multi_apply +from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply) class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, @@ -80,10 +80,11 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, gt_labels, gt_masks=None, proposals=None): - losses = dict() - x = self.extract_feat(img) + losses = dict() + + # RPN forward and loss if self.with_rpn: rpn_outs = self.rpn_head(x) rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, @@ -96,44 +97,43 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, else: proposal_list = proposals + # assign gts and sample proposals + if self.with_bbox or self.with_mask: + assign_results, sampling_results = multi_apply( + assign_and_sample, + proposal_list, + gt_bboxes, + gt_bboxes_ignore, + gt_labels, + cfg=self.train_cfg.rcnn) + + # bbox head forward and loss if self.with_bbox: - (pos_proposals, neg_proposals, pos_assigned_gt_inds, pos_gt_bboxes, - pos_gt_labels) = multi_apply( - sample_bboxes, - proposal_list, - gt_bboxes, - gt_bboxes_ignore, - gt_labels, - cfg=self.train_cfg.rcnn) - (labels, label_weights, bbox_targets, - bbox_weights) = self.bbox_head.get_bbox_target( - pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels, - self.train_cfg.rcnn) - - rois = bbox2roi([ - torch.cat([pos, neg], dim=0) - for pos, neg in zip(pos_proposals, neg_proposals) - ]) - # TODO: a more flexible way to configurate feat maps - roi_feats = self.bbox_roi_extractor( + rois = bbox2roi([res.bboxes for res in sampling_results]) + # TODO: a more flexible way to decide which feature maps to use + bbox_feats = self.bbox_roi_extractor( x[:self.bbox_roi_extractor.num_inputs], rois) - cls_score, bbox_pred = self.bbox_head(roi_feats) + cls_score, bbox_pred = self.bbox_head(bbox_feats) - loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, labels, - label_weights, bbox_targets, - bbox_weights) + bbox_targets = 
self.bbox_head.get_target( + sampling_results, gt_bboxes, gt_labels, self.train_cfg.rcnn) + loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, + *bbox_targets) losses.update(loss_bbox) + # mask head forward and loss if self.with_mask: - mask_targets = self.mask_head.get_mask_target( - pos_proposals, pos_assigned_gt_inds, gt_masks, - self.train_cfg.rcnn) - pos_rois = bbox2roi(pos_proposals) + pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results]) mask_feats = self.mask_roi_extractor( x[:self.mask_roi_extractor.num_inputs], pos_rois) mask_pred = self.mask_head(mask_feats) + + mask_targets = self.mask_head.get_target( + sampling_results, gt_masks, self.train_cfg.rcnn) + pos_labels = torch.cat( + [res.pos_gt_labels for res in sampling_results]) loss_mask = self.mask_head.loss(mask_pred, mask_targets, - torch.cat(pos_gt_labels)) + pos_labels) losses.update(loss_mask) return losses @@ -145,8 +145,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, x = self.extract_feat(img) proposal_list = self.simple_test_rpn( - x, img_meta, - self.test_cfg.rpn) if proposals is None else proposals + x, img_meta, self.test_cfg.rpn) if proposals is None else proposals det_bboxes, det_labels = self.simple_test_bboxes( x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale) diff --git a/mmdet/models/mask_heads/fcn_mask_head.py b/mmdet/models/mask_heads/fcn_mask_head.py index ba46bea..2c90a03 100644 --- a/mmdet/models/mask_heads/fcn_mask_head.py +++ b/mmdet/models/mask_heads/fcn_mask_head.py @@ -86,8 +86,11 @@ class FCNMaskHead(nn.Module): mask_pred = self.conv_logits(x) return mask_pred - def get_mask_target(self, pos_proposals, pos_assigned_gt_inds, gt_masks, - rcnn_train_cfg): + def get_target(self, sampling_results, gt_masks, rcnn_train_cfg): + pos_proposals = [res.pos_bboxes for res in sampling_results] + pos_assigned_gt_inds = [ + res.pos_assigned_gt_inds for res in sampling_results + ] mask_targets = mask_target(pos_proposals, 
pos_assigned_gt_inds, gt_masks, rcnn_train_cfg) return mask_targets -- GitLab