diff --git a/mmdet/core/anchor/__init__.py b/mmdet/core/anchor/__init__.py index a5f070f8bbe9a7293acf13fd89d480827eba3a69..dfeb3b407300beee4d530c111b4187693b190e2b 100644 --- a/mmdet/core/anchor/__init__.py +++ b/mmdet/core/anchor/__init__.py @@ -1,8 +1,10 @@ from .anchor_generator import AnchorGenerator from .anchor_target import anchor_inside_flags, anchor_target from .guided_anchor_target import ga_loc_target, ga_shape_target +from .point_generator import PointGenerator +from .point_target import point_target __all__ = [ 'AnchorGenerator', 'anchor_target', 'anchor_inside_flags', 'ga_loc_target', - 'ga_shape_target' + 'ga_shape_target', 'PointGenerator', 'point_target' ] diff --git a/mmdet/core/anchor/point_generator.py b/mmdet/core/anchor/point_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..c1a34dddd7a76946cf8177f0aea529a29cfa4a78 --- /dev/null +++ b/mmdet/core/anchor/point_generator.py @@ -0,0 +1,34 @@ +import torch + + +class PointGenerator(object): + + def _meshgrid(self, x, y, row_major=True): + xx = x.repeat(len(y)) + yy = y.view(-1, 1).repeat(1, len(x)).view(-1) + if row_major: + return xx, yy + else: + return yy, xx + + def grid_points(self, featmap_size, stride=16, device='cuda'): + feat_h, feat_w = featmap_size + shift_x = torch.arange(0., feat_w, device=device) * stride + shift_y = torch.arange(0., feat_h, device=device) * stride + shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) + stride = shift_x.new_full((shift_xx.shape[0], ), stride) + shifts = torch.stack([shift_xx, shift_yy, stride], dim=-1) + all_points = shifts.to(device) + return all_points + + def valid_flags(self, featmap_size, valid_size, device='cuda'): + feat_h, feat_w = featmap_size + valid_h, valid_w = valid_size + assert valid_h <= feat_h and valid_w <= feat_w + valid_x = torch.zeros(feat_w, dtype=torch.uint8, device=device) + valid_y = torch.zeros(feat_h, dtype=torch.uint8, device=device) + valid_x[:valid_w] = 1 + valid_y[:valid_h] = 1 + valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) + valid = valid_xx & valid_yy + return valid diff --git a/mmdet/core/anchor/point_target.py b/mmdet/core/anchor/point_target.py new file mode 100644 index 0000000000000000000000000000000000000000..1ab8d0260c93e479783fff9fbb02d680589ed28e --- /dev/null +++ b/mmdet/core/anchor/point_target.py @@ -0,0 +1,165 @@ +import torch + +from ..bbox import PseudoSampler, assign_and_sample, build_assigner +from ..utils import multi_apply + + +def point_target(proposals_list, + valid_flag_list, + gt_bboxes_list, + img_metas, + cfg, + gt_bboxes_ignore_list=None, + gt_labels_list=None, + label_channels=1, + sampling=True, + unmap_outputs=True): + """Compute corresponding GT box and classification targets for proposals. + + Args: + points_list (list[list]): Multi level points of each image. + valid_flag_list (list[list]): Multi level valid flags of each image. + gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image. + img_metas (list[dict]): Meta info of each image. + cfg (dict): train sample configs. + + Returns: + tuple + """ + num_imgs = len(img_metas) + assert len(proposals_list) == len(valid_flag_list) == num_imgs + + # points number of multi levels + num_level_proposals = [points.size(0) for points in proposals_list[0]] + + # concat all level points and flags to a single tensor + for i in range(num_imgs): + assert len(proposals_list[i]) == len(valid_flag_list[i]) + proposals_list[i] = torch.cat(proposals_list[i]) + valid_flag_list[i] = torch.cat(valid_flag_list[i]) + + # compute targets for each image + if gt_bboxes_ignore_list is None: + gt_bboxes_ignore_list = [None for _ in range(num_imgs)] + if gt_labels_list is None: + gt_labels_list = [None for _ in range(num_imgs)] + (all_labels, all_label_weights, all_bbox_gt, all_proposals, + all_proposal_weights, pos_inds_list, neg_inds_list) = multi_apply( + point_target_single, + proposals_list, + valid_flag_list, + gt_bboxes_list, + gt_bboxes_ignore_list, + gt_labels_list, + cfg=cfg, + label_channels=label_channels, + sampling=sampling, + unmap_outputs=unmap_outputs) + # no valid points + if any([labels is None for labels in all_labels]): + return None + # sampled points of all images + num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list]) + num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list]) + labels_list = images_to_levels(all_labels, num_level_proposals) + label_weights_list = images_to_levels(all_label_weights, + num_level_proposals) + bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals) + proposals_list = images_to_levels(all_proposals, num_level_proposals) + proposal_weights_list = images_to_levels(all_proposal_weights, + num_level_proposals) + return (labels_list, label_weights_list, bbox_gt_list, proposals_list, + proposal_weights_list, num_total_pos, num_total_neg) + + +def images_to_levels(target, num_level_grids): + """Convert targets by image to targets by feature level. + + [target_img0, target_img1] -> [target_level0, target_level1, ...] + """ + target = torch.stack(target, 0) + level_targets = [] + start = 0 + for n in num_level_grids: + end = start + n + level_targets.append(target[:, start:end].squeeze(0)) + start = end + return level_targets + + +def point_target_single(flat_proposals, + valid_flags, + gt_bboxes, + gt_bboxes_ignore, + gt_labels, + cfg, + label_channels=1, + sampling=True, + unmap_outputs=True): + inside_flags = valid_flags + if not inside_flags.any(): + return (None, ) * 7 + # assign gt and sample proposals + proposals = flat_proposals[inside_flags, :] + + if sampling: + assign_result, sampling_result = assign_and_sample( + proposals, gt_bboxes, gt_bboxes_ignore, None, cfg) + else: + bbox_assigner = build_assigner(cfg.assigner) + assign_result = bbox_assigner.assign(proposals, gt_bboxes, + gt_bboxes_ignore, gt_labels) + bbox_sampler = PseudoSampler() + sampling_result = bbox_sampler.sample(assign_result, proposals, + gt_bboxes) + + num_valid_proposals = proposals.shape[0] + bbox_gt = proposals.new_zeros([num_valid_proposals, 4]) + pos_proposals = torch.zeros_like(proposals) + proposals_weights = proposals.new_zeros([num_valid_proposals, 4]) + labels = proposals.new_zeros(num_valid_proposals, dtype=torch.long) + label_weights = proposals.new_zeros(num_valid_proposals, dtype=torch.float) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + pos_gt_bboxes = sampling_result.pos_gt_bboxes + bbox_gt[pos_inds, :] = pos_gt_bboxes + pos_proposals[pos_inds, :] = proposals[pos_inds, :] + proposals_weights[pos_inds, :] = 1.0 + if gt_labels is None: + labels[pos_inds] = 1 + else: + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + if cfg.pos_weight <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = cfg.pos_weight + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # map up to original set of proposals + if unmap_outputs: + num_total_proposals = flat_proposals.size(0) + labels = unmap(labels, num_total_proposals, inside_flags) + label_weights = unmap(label_weights, num_total_proposals, inside_flags) + bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags) + pos_proposals = unmap(pos_proposals, num_total_proposals, inside_flags) + proposals_weights = unmap(proposals_weights, num_total_proposals, + inside_flags) + + return (labels, label_weights, bbox_gt, pos_proposals, proposals_weights, + pos_inds, neg_inds) + + +def unmap(data, count, inds, fill=0): + """ Unmap a subset of item (data) back to the original set of items (of + size count) """ + if data.dim() == 1: + ret = data.new_full((count, ), fill) + ret[inds] = data + else: + new_size = (count, ) + data.size()[1:] + ret = data.new_full(new_size, fill) + ret[inds, :] = data + return ret diff --git a/mmdet/core/bbox/assigners/__init__.py b/mmdet/core/bbox/assigners/__init__.py index 594e8406b5dad0ef381a9dd9d2ec9fbb75e0efd7..93eebb775be7720f232f122050d5f753117f7731 100644 --- a/mmdet/core/bbox/assigners/__init__.py +++ b/mmdet/core/bbox/assigners/__init__.py @@ -2,7 +2,9 @@ from .approx_max_iou_assigner import ApproxMaxIoUAssigner from .assign_result import AssignResult from .base_assigner import BaseAssigner from .max_iou_assigner import MaxIoUAssigner +from .point_assigner import PointAssigner __all__ = [ - 'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult' + 'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult', + 'PointAssigner' ] diff --git a/mmdet/core/bbox/assigners/point_assigner.py b/mmdet/core/bbox/assigners/point_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..fe81e7d57e0a00ebbd732638927d629c4e87960a --- /dev/null +++ b/mmdet/core/bbox/assigners/point_assigner.py @@ -0,0 +1,116 @@ +import torch + +from .assign_result import AssignResult +from .base_assigner import BaseAssigner + + +class PointAssigner(BaseAssigner): + """Assign a corresponding gt bbox or background to each point. + + Each proposals will be assigned with `0`, or a positive integer + indicating the ground truth index. + + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + + """ + + def __init__(self, scale=4, pos_num=3): + self.scale = scale + self.pos_num = pos_num + + def assign(self, points, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None): + """Assign gt to points. + + This method assign a gt bbox to every points set, each points set + will be assigned with 0, or a positive number. + 0 means negative sample, positive number is the index (1-based) of + assigned gt. + The assignment is done in following steps, the order matters. + + 1. assign every points to 0 + 2. A point is assigned to some gt bbox if + (i) the point is within the k closest points to the gt bbox + (ii) the distance between this point and the gt is smaller than + other gt bboxes + + Args: + points (Tensor): points to be assigned, shape(n, 3) while last + dimension stands for (x, y, stride). + gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4). + gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are + labelled as `ignored`, e.g., crowd boxes in COCO. + gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ). + + Returns: + :obj:`AssignResult`: The assign result. + """ + if points.shape[0] == 0 or gt_bboxes.shape[0] == 0: + raise ValueError('No gt or bboxes') + points_xy = points[:, :2] + points_stride = points[:, 2] + points_lvl = torch.log2( + points_stride).int() # [3...,4...,5...,6...,7...] + lvl_min, lvl_max = points_lvl.min(), points_lvl.max() + num_gts, num_points = gt_bboxes.shape[0], points.shape[0] + + # assign gt box + gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2 + gt_bboxes_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6) + scale = self.scale + gt_bboxes_lvl = ((torch.log2(gt_bboxes_wh[:, 0] / scale) + + torch.log2(gt_bboxes_wh[:, 1] / scale)) / 2).int() + gt_bboxes_lvl = torch.clamp(gt_bboxes_lvl, min=lvl_min, max=lvl_max) + + # stores the assigned gt index of each point + assigned_gt_inds = points.new_zeros((num_points, ), dtype=torch.long) + # stores the assigned gt dist (to this point) of each point + assigned_gt_dist = points.new_full((num_points, ), float('inf')) + points_range = torch.arange(points.shape[0]) + + for idx in range(num_gts): + gt_lvl = gt_bboxes_lvl[idx] + # get the index of points in this level + lvl_idx = gt_lvl == points_lvl + points_index = points_range[lvl_idx] + # get the points in this level + lvl_points = points_xy[lvl_idx, :] + # get the center point of gt + gt_point = gt_bboxes_xy[[idx], :] + # get width and height of gt + gt_wh = gt_bboxes_wh[[idx], :] + # compute the distance between gt center and + # all points in this level + points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1) + # find the nearest k points to gt center in this level + min_dist, min_dist_index = torch.topk( + points_gt_dist, self.pos_num, largest=False) + # the index of nearest k points to gt center in this level + min_dist_points_index = points_index[min_dist_index] + # The less_than_recorded_index stores the index + # of min_dist that is less then the assigned_gt_dist. Where + # assigned_gt_dist stores the dist from previous assigned gt + # (if exist) to each point. + less_than_recorded_index = min_dist < assigned_gt_dist[ + min_dist_points_index] + # The min_dist_points_index stores the index of points satisfy: + # (1) it is k nearest to current gt center in this level. + # (2) it is closer to current gt center than other gt center. + min_dist_points_index = min_dist_points_index[ + less_than_recorded_index] + # assign the result + assigned_gt_inds[min_dist_points_index] = idx + 1 + assigned_gt_dist[min_dist_points_index] = min_dist[ + less_than_recorded_index] + + if gt_labels is not None: + assigned_labels = assigned_gt_inds.new_zeros((num_points, )) + pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze() + if pos_inds.numel() > 0: + assigned_labels[pos_inds] = gt_labels[ + assigned_gt_inds[pos_inds] - 1] + else: + assigned_labels = None + + return AssignResult( + num_gts, assigned_gt_inds, None, labels=assigned_labels)