diff --git a/mmdet/core/anchor/anchor_generator.py b/mmdet/core/anchor/anchor_generator.py index 61304307415392b17f78dd82f53c731dcca1a144..ec27c4c02c919af8103440966a69c723291ea427 100644 --- a/mmdet/core/anchor/anchor_generator.py +++ b/mmdet/core/anchor/anchor_generator.py @@ -2,6 +2,17 @@ import torch class AnchorGenerator(object): + """ + Examples: + >>> from mmdet.core import AnchorGenerator + >>> self = AnchorGenerator(9, [1.], [1.]) + >>> all_anchors = self.grid_anchors((2, 2), device='cpu') + >>> print(all_anchors) + tensor([[ 0., 0., 8., 8.], + [16., 0., 24., 8.], + [ 0., 16., 8., 24.], + [16., 16., 24., 24.]]) + """ def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None): self.base_size = base_size diff --git a/mmdet/core/bbox/transforms.py b/mmdet/core/bbox/transforms.py index 580b9bdfb24d43e80600ad44a70cde6b8ccc58e3..b9d1e660546abfde71b1c8010ef2b5cb4c380957 100644 --- a/mmdet/core/bbox/transforms.py +++ b/mmdet/core/bbox/transforms.py @@ -37,6 +37,44 @@ def delta2bbox(rois, stds=[1, 1, 1, 1], max_shape=None, wh_ratio_clip=16 / 1000): + """ + Apply deltas to shift/scale base boxes. + + Typically the rois are anchor or proposed bounding boxes and the deltas are + network outputs used to shift/scale those boxes. + + Args: + rois (Tensor): boxes to be transformed. Has shape (N, 4) + deltas (Tensor): encoded offsets with respect to each roi. + Has shape (N, 4). Note N = num_anchors * W * H when rois is a grid + of anchors. Offset encoding follows [1]_. + means (list): denormalizing means for delta coordinates + stds (list): denormalizing standard deviation for delta coordinates + max_shape (tuple[int, int]): maximum bounds for boxes. specifies (H, W) + wh_ratio_clip (float): maximum aspect ratio for boxes. + + Returns: + Tensor: boxes with shape (N, 4), where columns represent + tl_x, tl_y, br_x, br_y. + + References: + .. [1] https://arxiv.org/abs/1311.2524 + + Example: + >>> rois = torch.Tensor([[ 0., 0., 1., 1.], + >>> [ 0., 0., 1., 1.], + >>> [ 0., 0., 1., 1.], + >>> [ 5., 5., 5., 5.]]) + >>> deltas = torch.Tensor([[ 0., 0., 0., 0.], + >>> [ 1., 1., 1., 1.], + >>> [ 0., 0., 2., -1.], + >>> [ 0.7, -1.9, -0.5, 0.3]]) + >>> delta2bbox(rois, deltas, max_shape=(32, 32)) + tensor([[0.0000, 0.0000, 1.0000, 1.0000], + [0.2817, 0.2817, 4.7183, 4.7183], + [0.0000, 0.6321, 7.3891, 0.3679], + [5.8967, 2.9251, 5.5033, 3.2749]]) + """ means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4) stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4) denorm_deltas = deltas * stds + means @@ -47,14 +85,19 @@ def delta2bbox(rois, max_ratio = np.abs(np.log(wh_ratio_clip)) dw = dw.clamp(min=-max_ratio, max=max_ratio) dh = dh.clamp(min=-max_ratio, max=max_ratio) + # Compute center of each roi px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx) py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy) + # Compute width/height of each roi pw = (rois[:, 2] - rois[:, 0] + 1.0).unsqueeze(1).expand_as(dw) ph = (rois[:, 3] - rois[:, 1] + 1.0).unsqueeze(1).expand_as(dh) + # Use exp(network energy) to enlarge/shrink each roi gw = pw * dw.exp() gh = ph * dh.exp() + # Use network energy to shift the center of each roi gx = torch.addcmul(px, 1, pw, dx) # gx = px + pw * dx gy = torch.addcmul(py, 1, ph, dy) # gy = py + ph * dy + # Convert center-xy/width/height to top-left, bottom-right x1 = gx - gw * 0.5 + 0.5 y1 = gy - gh * 0.5 + 0.5 x2 = gx + gw * 0.5 - 0.5 diff --git a/mmdet/core/post_processing/bbox_nms.py b/mmdet/core/post_processing/bbox_nms.py index 01beecd43abc7641d43773b192f988a55f4295d9..ce3794c6450347d9e0b469066584ccfcdef6f8d0 100644 --- a/mmdet/core/post_processing/bbox_nms.py +++ b/mmdet/core/post_processing/bbox_nms.py @@ -13,7 +13,8 @@ def multiclass_nms(multi_bboxes, Args: multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) - multi_scores (Tensor): shape (n, #class) + multi_scores (Tensor): shape (n, #class), where the 0th column + contains scores of the background class, but this will be ignored. score_thr (float): bbox threshold, bboxes with scores lower than it will not be considered. nms_thr (float): NMS IoU threshold diff --git a/mmdet/models/anchor_heads/anchor_head.py b/mmdet/models/anchor_heads/anchor_head.py index 55c25de9e11c027dd6be213cf4512b7acf1a67ee..d526b3b3beb37b3177c2bb25f3bfcaaa7fe887d4 100644 --- a/mmdet/models/anchor_heads/anchor_head.py +++ b/mmdet/models/anchor_heads/anchor_head.py @@ -16,6 +16,8 @@ class AnchorHead(nn.Module): """Anchor-based head (RPN, RetinaNet, SSD, etc.). Args: + num_classes (int): Number of categories including the background + category. in_channels (int): Number of channels in the input feature map. feat_channels (int): Number of channels of the feature map. anchor_scales (Iterable): Anchor scales. @@ -45,6 +47,7 @@ class AnchorHead(nn.Module): loss_bbox=dict( type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)): super(AnchorHead, self).__init__() + # NOTE: in_channels is only used in child classes (e.g. RetinaHead) self.in_channels = in_channels self.num_classes = num_classes self.feat_channels = feat_channels @@ -62,6 +65,10 @@ class AnchorHead(nn.Module): self.cls_out_channels = num_classes - 1 else: self.cls_out_channels = num_classes + + if self.cls_out_channels <= 0: + raise ValueError('num_classes={} is too small'.format(num_classes)) + self.loss_cls = build_loss(loss_cls) self.loss_bbox = build_loss(loss_bbox) self.fp16_enabled = False @@ -202,6 +209,46 @@ class AnchorHead(nn.Module): @force_fp32(apply_to=('cls_scores', 'bbox_preds')) def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg, rescale=False): + """ + Transform network output for a batch into labeled boxes. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W) + img_metas (list[dict]): size / scale info for each image + cfg (mmcv.Config): test / postprocessing configuration + rescale (bool): if True, return boxes in original image space + + Returns: + list[tuple[Tensor, Tensor]]: each item in result_list is 2-tuple. + The first item is an (n, 5) tensor, where the first 4 columns + are bounding box positions (tl_x, tl_y, br_x, br_y) and the + 5-th column is a score between 0 and 1. The second item is a + (n,) tensor where each item is the class index of the + corresponding box. + + Example: + >>> import mmcv + >>> self = AnchorHead(num_classes=9, in_channels=1, + >>> feat_channels=1) + >>> img_metas = [{'img_shape': (32, 32, 3), 'scale_factor': 1}] + >>> cfg = mmcv.Config(dict( + >>> score_thr=0.00, + >>> nms=dict(type='nms', iou_thr=1.0), + >>> max_per_img=10)) + >>> feat = torch.rand(1, 1, 3, 3) + >>> cls_score, bbox_pred = self.forward_single(feat) + >>> # note the input lists are over different levels, not images + >>> cls_scores, bbox_preds = [cls_score], [bbox_pred] + >>> result_list = self.get_bboxes(cls_scores, bbox_preds, + >>> img_metas, cfg) + >>> det_bboxes, det_labels = result_list[0] + >>> assert len(result_list) == 1 + >>> assert det_bboxes.shape[1] == 5 + >>> assert len(det_bboxes) == len(det_labels) == cfg.max_per_img + """ assert len(cls_scores) == len(bbox_preds) num_levels = len(cls_scores) @@ -229,18 +276,21 @@ class AnchorHead(nn.Module): return result_list def get_bboxes_single(self, - cls_scores, - bbox_preds, + cls_score_list, + bbox_pred_list, mlvl_anchors, img_shape, scale_factor, cfg, rescale=False): - assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors) + """ + Transform outputs for a single batch item into labeled boxes. + """ + assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors) mlvl_bboxes = [] mlvl_scores = [] - for cls_score, bbox_pred, anchors in zip(cls_scores, bbox_preds, - mlvl_anchors): + for cls_score, bbox_pred, anchors in zip(cls_score_list, + bbox_pred_list, mlvl_anchors): assert cls_score.size()[-2:] == bbox_pred.size()[-2:] cls_score = cls_score.permute(1, 2, 0).reshape(-1, self.cls_out_channels) @@ -251,6 +301,7 @@ class AnchorHead(nn.Module): bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) nms_pre = cfg.get('nms_pre', -1) if nms_pre > 0 and scores.shape[0] > nms_pre: + # Get maximum scores for foreground classes. if self.use_sigmoid_cls: max_scores, _ = scores.max(dim=1) else: @@ -268,6 +319,7 @@ class AnchorHead(nn.Module): mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) mlvl_scores = torch.cat(mlvl_scores) if self.use_sigmoid_cls: + # Add a dummy background class to the front when using sigmoid padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) mlvl_scores = torch.cat([padding, mlvl_scores], dim=1) det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,