diff --git a/mmdet/core/rpn_ops/anchor_target.py b/mmdet/core/rpn_ops/anchor_target.py
index 6062633c0c29ec8346b55adf342e3bd4d9e9ce70..3cf651b5c46e98da6673e58c0299c6f7a90e632f 100644
--- a/mmdet/core/rpn_ops/anchor_target.py
+++ b/mmdet/core/rpn_ops/anchor_target.py
@@ -1,93 +1,85 @@
 import torch
-import numpy as np
 
-from ..bbox_ops import (bbox_assign, bbox_transform, bbox_sampling)
+from ..bbox_ops import bbox_assign, bbox_transform, bbox_sampling
+from ..utils import multi_apply
 
 
-def anchor_target(anchor_list, valid_flag_list, featmap_sizes, gt_bboxes_list,
-                  img_metas, target_means, target_stds, cfg):
-    """Compute regression and classification targets for anchors.
-    There may be multiple feature levels,
+def anchor_target(anchor_list, valid_flag_list, gt_bboxes_list, img_metas,
+                  target_means, target_stds, cfg):
+    """Compute regression and classification targets for anchors.
 
     Args:
-        anchor_list(list): anchors of each feature map level
-        featmap_sizes(list): feature map sizes
-        gt_bboxes_list(list): ground truth bbox of images in a mini-batch
-        img_shapes(list): shape of each image in a mini-batch
-        cfg(dict): configs
+        anchor_list (list[list]): Multi level anchors of each image.
+        valid_flag_list (list[list]): Multi level valid flags of each image.
+        gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+        img_metas (list[dict]): Meta info of each image.
+        target_means (Iterable): Mean value of regression targets.
+        target_stds (Iterable): Std value of regression targets.
+        cfg (dict): RPN train configs.
 
     Returns:
         tuple
     """
     num_imgs = len(img_metas)
-    num_levels = len(featmap_sizes)
-    if len(anchor_list) == num_levels:
-        all_anchors = torch.cat(anchor_list, 0)
-        anchor_nums = [anchors.size(0) for anchors in anchor_list]
-        use_isomerism_anchors = False
-    elif len(anchor_list) == num_imgs:
-        # using different anchors for different images
-        all_anchors_list = [
-            torch.cat(anchor_list[img_id], 0) for img_id in range(num_imgs)
-        ]
-        anchor_nums = [anchors.size(0) for anchors in anchor_list[0]]
-        use_isomerism_anchors = True
-    else:
-        raise ValueError('length of anchor_list should be equal to number of '
-                         'feature lvls or number of images in a batch')
-    all_labels = []
-    all_label_weights = []
-    all_bbox_targets = []
-    all_bbox_weights = []
-    num_total_sampled = 0
-    for img_id in range(num_imgs):
-        if isinstance(valid_flag_list[img_id], list):
-            valid_flags = torch.cat(valid_flag_list[img_id], 0)
-        else:
-            valid_flags = valid_flag_list[img_id]
-        if use_isomerism_anchors:
-            all_anchors = all_anchors_list[img_id]
-        inside_flags = anchor_inside_flags(all_anchors, valid_flags,
-                                           img_metas[img_id]['img_shape'][:2],
-                                           cfg.allowed_border)
-        if not inside_flags.any():
-            return None
-        gt_bboxes = gt_bboxes_list[img_id]
-        anchor_targets = anchor_target_single(all_anchors, inside_flags,
-                                              gt_bboxes, target_means,
-                                              target_stds, cfg)
-        (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
-         neg_inds) = anchor_targets
-        all_labels.append(labels)
-        all_label_weights.append(label_weights)
-        all_bbox_targets.append(bbox_targets)
-        all_bbox_weights.append(bbox_weights)
-        num_total_sampled += max(pos_inds.numel() + neg_inds.numel(), 1)
-    all_labels = torch.stack(all_labels, 0)
-    all_label_weights = torch.stack(all_label_weights, 0)
-    all_bbox_targets = torch.stack(all_bbox_targets, 0)
-    all_bbox_weights = torch.stack(all_bbox_weights, 0)
-    # split into different feature levels
-    labels_list = []
-    label_weights_list = []
-    bbox_targets_list = []
-    bbox_weights_list = []
+    assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+    # anchor number of multi levels
+    num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+    # concat all level anchors and flags to a single tensor
+    for i in range(num_imgs):
+        assert len(anchor_list[i]) == len(valid_flag_list[i])
+        anchor_list[i] = torch.cat(anchor_list[i])
+        valid_flag_list[i] = torch.cat(valid_flag_list[i])
+
+    # compute targets for each image
+    means_replicas = [target_means for _ in range(num_imgs)]
+    stds_replicas = [target_stds for _ in range(num_imgs)]
+    cfg_replicas = [cfg for _ in range(num_imgs)]
+    (all_labels, all_label_weights, all_bbox_targets,
+     all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
+         anchor_target_single, anchor_list, valid_flag_list, gt_bboxes_list,
+         img_metas, means_replicas, stds_replicas, cfg_replicas)
+    # no valid anchors
+    if any([labels is None for labels in all_labels]):
+        return None
+    # sampled anchors of all images
+    num_total_samples = sum([
+        max(pos_inds.numel() + neg_inds.numel(), 1)
+        for pos_inds, neg_inds in zip(pos_inds_list, neg_inds_list)
+    ])
+    # split targets to a list w.r.t. multiple levels
+    labels_list = images_to_levels(all_labels, num_level_anchors)
+    label_weights_list = images_to_levels(all_label_weights, num_level_anchors)
+    bbox_targets_list = images_to_levels(all_bbox_targets, num_level_anchors)
+    bbox_weights_list = images_to_levels(all_bbox_weights, num_level_anchors)
+    return (labels_list, label_weights_list, bbox_targets_list,
+            bbox_weights_list, num_total_samples)
+
+
+def images_to_levels(target, num_level_anchors):
+    """Convert targets by image to targets by feature level.
+
+    [target_img0, target_img1] -> [target_level0, target_level1, ...]
+    """
+    target = torch.stack(target, 0)
+    level_targets = []
     start = 0
-    for anchor_num in anchor_nums:
-        end = start + anchor_num
-        labels_list.append(all_labels[:, start:end].squeeze(0))
-        label_weights_list.append(all_label_weights[:, start:end].squeeze(0))
-        bbox_targets_list.append(all_bbox_targets[:, start:end].squeeze(0))
-        bbox_weights_list.append(all_bbox_weights[:, start:end].squeeze(0))
+    for n in num_level_anchors:
+        end = start + n
+        level_targets.append(target[:, start:end].squeeze(0))
         start = end
-    return (labels_list, label_weights_list, bbox_targets_list,
-            bbox_weights_list, num_total_sampled)
+    return level_targets
 
 
-def anchor_target_single(all_anchors, inside_flags, gt_bboxes, target_means,
-                         target_stds, cfg):
+def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
+                         target_means, target_stds, cfg):
+    inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+                                       img_meta['img_shape'][:2],
+                                       cfg.allowed_border)
+    if not inside_flags.any():
+        return (None, ) * 6
     # assign gt and sample anchors
-    anchors = all_anchors[inside_flags, :]
+    anchors = flat_anchors[inside_flags, :]
     assigned_gt_inds, argmax_overlaps, max_overlaps = bbox_assign(
         anchors,
         gt_bboxes,
@@ -120,7 +112,7 @@ def anchor_target_single(all_anchors, inside_flags, gt_bboxes, target_means,
     label_weights[neg_inds] = 1.0
 
     # map up to original set of anchors
-    num_total_anchors = all_anchors.size(0)
+    num_total_anchors = flat_anchors.size(0)
     labels = unmap(labels, num_total_anchors, inside_flags)
     label_weights = unmap(label_weights, num_total_anchors, inside_flags)
    bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
@@ -130,27 +122,20 @@ def anchor_target_single(all_anchors, inside_flags, gt_bboxes, target_means,
             neg_inds)
 
 
-def anchor_inside_flags(all_anchors, valid_flags, img_shape, allowed_border=0):
+def anchor_inside_flags(flat_anchors, valid_flags, img_shape,
+                        allowed_border=0):
     img_h, img_w = img_shape[:2]
     if allowed_border >= 0:
         inside_flags = valid_flags & \
-            (all_anchors[:, 0] >= -allowed_border) & \
-            (all_anchors[:, 1] >= -allowed_border) & \
-            (all_anchors[:, 2] < img_w + allowed_border) & \
-            (all_anchors[:, 3] < img_h + allowed_border)
+            (flat_anchors[:, 0] >= -allowed_border) & \
+            (flat_anchors[:, 1] >= -allowed_border) & \
+            (flat_anchors[:, 2] < img_w + allowed_border) & \
+            (flat_anchors[:, 3] < img_h + allowed_border)
     else:
         inside_flags = valid_flags
     return inside_flags
 
 
-def unique(tensor):
-    if tensor.is_cuda:
-        u_tensor = np.unique(tensor.cpu().numpy())
-        return tensor.new_tensor(u_tensor)
-    else:
-        return torch.unique(tensor)
-
-
 def unmap(data, count, inds, fill=0):
     """ Unmap a subset of item (data) back to the original set of items (of
     size count) """
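Note on the refactor above: anchor_target now delegates its per-image loop to multi_apply, which is imported from ..utils but not shown in this diff. As a point of reference, a minimal sketch of such a helper (the map-and-transpose behavior the new code relies on) could look like the following; the exact mmdet implementation may differ in details:

    from functools import partial

    def multi_apply(func, *args, **kwargs):
        # apply func once per image: each positional arg is a per-image list
        pfunc = partial(func, **kwargs) if kwargs else func
        map_results = map(pfunc, *args)
        # transpose per-image result tuples into a tuple of per-field lists:
        # [(a0, b0), (a1, b1)] -> ([a0, a1], [b0, b1])
        return tuple(map(list, zip(*map_results)))

Under this contract, the multi_apply call in anchor_target yields six lists with one entry per image, which is exactly the per-image layout that the subsequent images_to_levels calls regroup by feature level.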
diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py
index 8e7d9feffb93b7c0f85071bcccf19e63273acedb..b803360072b94f129876f11abaa87ed997fb61a7 100644
--- a/mmdet/datasets/coco.py
+++ b/mmdet/datasets/coco.py
@@ -212,7 +212,7 @@ class CocoDataset(Dataset):
         # apply transforms
         flip = True if np.random.rand() < self.flip_ratio else False
         img_scale = random_scale(self.img_scales)  # sample a scale
-        img, img_shape, scale_factor = self.img_transform(
+        img, img_shape, pad_shape, scale_factor = self.img_transform(
             img, img_scale, flip)
         if self.proposals is not None:
             proposals = self.bbox_transform(proposals, img_shape,
@@ -232,6 +232,7 @@ class CocoDataset(Dataset):
         img_meta = dict(
             ori_shape=ori_shape,
             img_shape=img_shape,
+            pad_shape=pad_shape,
             scale_factor=scale_factor,
             flip=flip)
 
@@ -260,12 +261,13 @@ class CocoDataset(Dataset):
                     if self.proposals is not None else None)
 
         def prepare_single(img, scale, flip, proposal=None):
-            _img, img_shape, scale_factor = self.img_transform(
+            _img, img_shape, pad_shape, scale_factor = self.img_transform(
                 img, scale, flip)
            _img = to_tensor(_img)
            _img_meta = dict(
                ori_shape=(img_info['height'], img_info['width'], 3),
                img_shape=img_shape,
+                pad_shape=pad_shape,
                scale_factor=scale_factor,
                flip=flip)
            if proposal is not None:
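With this change, img_meta carries three related shapes: ori_shape (on disk), img_shape (after resizing), and pad_shape (after padding). A hypothetical illustration of the resulting dict, assuming a 427x640 image, a (1333, 800) scale, and size_divisor=32:

    # made-up values for one COCO image
    img_meta = dict(
        ori_shape=(427, 640, 3),   # shape before any transform
        img_shape=(800, 1199, 3),  # shape after resize, before padding
        pad_shape=(800, 1216, 3),  # shape after padding to a multiple of 32
        scale_factor=1.873,        # resize ratio: min(1333/640, 800/427)
        flip=False)

pad_shape is what downstream consumers such as RPNHead.get_anchors need, since anchors are laid out over the padded feature maps rather than the raw image extent.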
diff --git a/mmdet/datasets/transforms.py b/mmdet/datasets/transforms.py
index 3a41e8d4cd4842d6102d982b8d0e77f9151779f5..6cdba4e972e67f717ddf879a56258d8fda8adde9 100644
--- a/mmdet/datasets/transforms.py
+++ b/mmdet/datasets/transforms.py
@@ -36,8 +36,11 @@ class ImageTransform(object):
             img = mmcv.imflip(img)
         if self.size_divisor is not None:
             img = mmcv.impad_to_multiple(img, self.size_divisor)
+            pad_shape = img.shape
+        else:
+            pad_shape = img_shape
         img = img.transpose(2, 0, 1)
-        return img, img_shape, scale_factor
+        return img, img_shape, pad_shape, scale_factor
 
 
 def bbox_flip(bboxes, img_shape):
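The padded shape is recorded when size_divisor is set and falls back to img_shape otherwise. A quick sanity check of the padding arithmetic (a sketch, assuming mmcv.impad_to_multiple pads the bottom/right edges with zeros):

    import numpy as np
    import mmcv

    img = np.zeros((800, 1199, 3), dtype=np.uint8)
    padded = mmcv.impad_to_multiple(img, 32)
    # 800 is already a multiple of 32; 1199 -> ceil(1199 / 32) * 32 = 1216
    print(padded.shape)  # (800, 1216, 3)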
diff --git a/mmdet/models/rpn_heads/rpn_head.py b/mmdet/models/rpn_heads/rpn_head.py
index e81f19310e8e7e23b5e23be04888e511a7bd897d..68a81833e099f508ae4f776e62558fd3afb3e1d9 100644
--- a/mmdet/models/rpn_heads/rpn_head.py
+++ b/mmdet/models/rpn_heads/rpn_head.py
@@ -6,18 +6,35 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from mmdet.core import (AnchorGenerator, anchor_target, bbox_transform_inv,
-                        weighted_cross_entropy, weighted_smoothl1,
+                        multi_apply, weighted_cross_entropy, weighted_smoothl1,
                         weighted_binary_cross_entropy)
 from mmdet.ops import nms
-from ..utils import multi_apply, normal_init
+from ..utils import normal_init
 
 
 class RPNHead(nn.Module):
+    """Network head of RPN.
+
+                                  / - rpn_cls (1x1 conv)
+    input - rpn_conv (3x3 conv) -
+                                  \ - rpn_reg (1x1 conv)
+
+    Args:
+        in_channels (int): Number of channels in the input feature map.
+        feat_channels (int): Number of channels for the RPN feature map.
+        anchor_scales (Iterable): Anchor scales.
+        anchor_ratios (Iterable): Anchor aspect ratios.
+        anchor_strides (Iterable): Anchor strides.
+        anchor_base_sizes (Iterable): Anchor base sizes.
+        target_means (Iterable): Mean values of regression targets.
+        target_stds (Iterable): Std values of regression targets.
+        use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
+            (softmax by default)
+    """
 
     def __init__(self,
                  in_channels,
-                 feat_channels=512,
-                 coarsest_stride=32,
+                 feat_channels=256,
                  anchor_scales=[8, 16, 32],
                  anchor_ratios=[0.5, 1.0, 2.0],
                  anchor_strides=[4, 8, 16, 32, 64],
@@ -28,7 +45,6 @@ class RPNHead(nn.Module):
         super(RPNHead, self).__init__()
         self.in_channels = in_channels
         self.feat_channels = feat_channels
-        self.coarsest_stride = coarsest_stride
         self.anchor_scales = anchor_scales
         self.anchor_ratios = anchor_ratios
         self.anchor_strides = anchor_strides
@@ -66,38 +82,42 @@ class RPNHead(nn.Module):
         return multi_apply(self.forward_single, feats)
 
     def get_anchors(self, featmap_sizes, img_metas):
-        """Get anchors given a list of feature map sizes, and get valid flags
-        at the same time. (Extra padding regions should be marked as invalid)
+        """Get anchors according to feature map sizes.
+
+        Args:
+            featmap_sizes (list[tuple]): Multi-level feature map sizes.
+            img_metas (list[dict]): Image meta info.
+
+        Returns:
+            tuple: anchors of each image, valid flags of each image
         """
-        # calculate actual image shapes
-        padded_img_shapes = []
-        for img_meta in img_metas:
-            h, w = img_meta['img_shape'][:2]
-            padded_h = int(
-                np.ceil(h / self.coarsest_stride) * self.coarsest_stride)
-            padded_w = int(
-                np.ceil(w / self.coarsest_stride) * self.coarsest_stride)
-            padded_img_shapes.append((padded_h, padded_w))
-        # generate anchors for different feature levels
-        # len = feature levels
-        anchor_list = []
-        # len = imgs per gpu
-        valid_flag_list = [[] for _ in range(len(img_metas))]
-        for i in range(len(featmap_sizes)):
-            anchor_stride = self.anchor_strides[i]
+        num_imgs = len(img_metas)
+        num_levels = len(featmap_sizes)
+
+        # since feature map sizes of all images are the same, we only compute
+        # anchors for one time
+        multi_level_anchors = []
+        for i in range(num_levels):
             anchors = self.anchor_generators[i].grid_anchors(
-                featmap_sizes[i], anchor_stride)
-            anchor_list.append(anchors)
-            # for each image in this feature level, get valid flags
-            featmap_size = featmap_sizes[i]
-            for img_id, (h, w) in enumerate(padded_img_shapes):
-                valid_feat_h = min(
-                    int(np.ceil(h / anchor_stride)), featmap_size[0])
-                valid_feat_w = min(
-                    int(np.ceil(w / anchor_stride)), featmap_size[1])
+                featmap_sizes[i], self.anchor_strides[i])
+            multi_level_anchors.append(anchors)
+        anchor_list = [multi_level_anchors for _ in range(num_imgs)]
+
+        # for each image, we compute valid flags of multi level anchors
+        valid_flag_list = []
+        for img_id, img_meta in enumerate(img_metas):
+            multi_level_flags = []
+            for i in range(num_levels):
+                anchor_stride = self.anchor_strides[i]
+                feat_h, feat_w = featmap_sizes[i]
+                h, w, _ = img_meta['pad_shape']
+                valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
+                valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
                 flags = self.anchor_generators[i].valid_flags(
-                    featmap_size, (valid_feat_h, valid_feat_w))
-                valid_flag_list[img_id].append(flags)
+                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
+                multi_level_flags.append(flags)
+            valid_flag_list.append(multi_level_flags)
+
         return anchor_list, valid_flag_list
 
     def loss_single(self, rpn_cls_score, rpn_bbox_pred, labels, label_weights,
@@ -135,7 +155,7 @@ class RPNHead(nn.Module):
         anchor_list, valid_flag_list = self.get_anchors(
             featmap_sizes, img_shapes)
         cls_reg_targets = anchor_target(
-            anchor_list, valid_flag_list, featmap_sizes, gt_bboxes, img_shapes,
+            anchor_list, valid_flag_list, gt_bboxes, img_shapes,
             self.target_means, self.target_stds, cfg)
         if cls_reg_targets is None:
             return None
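The key behavioral change in get_anchors is that valid flags now derive from each image's own pad_shape instead of a global coarsest_stride assumption, so anchors falling in batch padding (regions added when images are collated to a common tensor size) are marked invalid per image. A worked example with made-up numbers:

    import numpy as np

    # hypothetical batch: images collated to 800x1216, so the stride-16 level
    # has a 50x76 feature map; this image was only padded to 608x1216
    feat_h, feat_w = 50, 76
    h, w = 608, 1216  # from img_meta['pad_shape'][:2]
    anchor_stride = 16
    valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)  # 38
    valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)  # 76
    # feature rows 38..49 lie entirely inside batch padding, so every anchor
    # centered there is flagged invalid before target assignment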
diff --git a/tools/configs/r50_fpn_frcnn_1x.py b/tools/configs/r50_fpn_frcnn_1x.py
index 71505ae6e9302f3693020c9c944a1b01617e8161..4ce93e623e3d66aaf7c1520bc818e6e33d29e184 100644
--- a/tools/configs/r50_fpn_frcnn_1x.py
+++ b/tools/configs/r50_fpn_frcnn_1x.py
@@ -18,7 +18,6 @@ model = dict(
         type='RPNHead',
         in_channels=256,
         feat_channels=256,
-        coarsest_stride=32,
         anchor_scales=[8],
         anchor_ratios=[0.5, 1.0, 2.0],
         anchor_strides=[4, 8, 16, 32, 64],
diff --git a/tools/configs/r50_fpn_maskrcnn_1x.py b/tools/configs/r50_fpn_maskrcnn_1x.py
index e6b353585f79b20d05b2a3c66d61c66fabf0a4cf..931f051b356c4f01119d20d397ceecd4d3babcf5 100644
--- a/tools/configs/r50_fpn_maskrcnn_1x.py
+++ b/tools/configs/r50_fpn_maskrcnn_1x.py
@@ -18,7 +18,6 @@ model = dict(
         type='RPNHead',
         in_channels=256,
         feat_channels=256,
-        coarsest_stride=32,
         anchor_scales=[8],
         anchor_ratios=[0.5, 1.0, 2.0],
         anchor_strides=[4, 8, 16, 32, 64],
diff --git a/tools/configs/r50_fpn_rpn_1x.py b/tools/configs/r50_fpn_rpn_1x.py
index c982f0402b39e7bc7fcfb38d91fdb47ad5ffb17f..a00cab9de8013455f498d96cb12c9801aafcc343 100644
--- a/tools/configs/r50_fpn_rpn_1x.py
+++ b/tools/configs/r50_fpn_rpn_1x.py
@@ -18,7 +18,6 @@ model = dict(
         type='RPNHead',
         in_channels=256,
         feat_channels=256,
-        coarsest_stride=32,
         anchor_scales=[8],
         anchor_ratios=[0.5, 1.0, 2.0],
         anchor_strides=[4, 8, 16, 32, 64],
@@ -104,5 +103,5 @@ dist_params = dict(backend='gloo')
 log_level = 'INFO'
 work_dir = './work_dirs/fpn_rpn_r50_1x'
 load_from = None
-resume_from = None
+resume_from = None
 workflow = [('train', 1)]