From b69667001f250a54a37129a000a8d5160e047239 Mon Sep 17 00:00:00 2001 From: Jon Crall <erotemic@gmail.com> Date: Tue, 24 Dec 2019 04:02:28 -0500 Subject: [PATCH] Allow for images to contain zero true detections (#1531) * Allow for images to contain zero true detections * Allow for empty assignment in PointAssigner * Allow ApproxMaxIouAssigner to return an empty result * Fix CascadeRNN forward when entire batch has no truth * Correctly assign boxes to background when there is no truth * Fix assignment tests * Make flatten robust * Fix bbox loss with empty pred/truth * Fix logic error in BBoxHead.loss * Add tests for empty truth cases * tests faster rcnn empty forward * Skip roipool forward tests if torchvision is not installed * Add tests for bbox/anchor heads * Consolidate test_forward and test_forward2 * Fix assign_results.labels = None when gt_labels is given; Add test for this case * Fix OHEM Sampler with zero truth * remove xdev * resolve 3 reviews * Fix flake8 * refactoring * fix yaml format * add filter flag * minor fix * delete redundant code in load anno * fix flake8 errors * quick fix for empty truth with masks * fix yapf error * fix mask padding for empty masks Co-authored-by: Cao Yuhang <yhcao6@gmail.com> Co-authored-by: Kai Chen <chenkaidev@gmail.com> --- .../bbox/assigners/approx_max_iou_assigner.py | 14 +- mmdet/core/bbox/assigners/assign_result.py | 75 ++++- mmdet/core/bbox/assigners/max_iou_assigner.py | 31 ++- mmdet/core/bbox/assigners/point_assigner.py | 20 +- mmdet/core/bbox/geometry.py | 31 ++- mmdet/core/bbox/samplers/base_sampler.py | 2 +- mmdet/datasets/coco.py | 2 +- mmdet/datasets/custom.py | 6 +- mmdet/datasets/pipelines/loading.py | 15 +- mmdet/datasets/pipelines/transforms.py | 5 +- mmdet/models/bbox_heads/bbox_head.py | 40 +-- mmdet/models/bbox_heads/convfc_bbox_head.py | 8 +- mmdet/models/detectors/cascade_rcnn.py | 6 + mmdet/models/detectors/two_stage.py | 18 +- tests/test_assigner.py | 261 ++++++++++++++++++ tests/test_forward.py | 
158 +++++++++++ tests/test_heads.py | 171 ++++++++++++ tests/test_sampler.py | 235 ++++++++++++++++ 18 files changed, 1032 insertions(+), 66 deletions(-) create mode 100644 tests/test_assigner.py create mode 100644 tests/test_heads.py create mode 100644 tests/test_sampler.py diff --git a/mmdet/core/bbox/assigners/approx_max_iou_assigner.py b/mmdet/core/bbox/assigners/approx_max_iou_assigner.py index a52ed26..e7d3510 100644 --- a/mmdet/core/bbox/assigners/approx_max_iou_assigner.py +++ b/mmdet/core/bbox/assigners/approx_max_iou_assigner.py @@ -74,9 +74,9 @@ class ApproxMaxIoUAssigner(MaxIoUAssigner): Args: approxs (Tensor): Bounding boxes to be assigned, - shape(approxs_per_octave*n, 4). + shape(approxs_per_octave*n, 4). squares (Tensor): Base Bounding boxes to be assigned, - shape(n, 4). + shape(n, 4). approxs_per_octave (int): number of approxs per octave gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4). gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are @@ -86,11 +86,15 @@ class ApproxMaxIoUAssigner(MaxIoUAssigner): Returns: :obj:`AssignResult`: The assign result. """ - - if squares.shape[0] == 0 or gt_bboxes.shape[0] == 0: - raise ValueError('No gt or approxs') num_squares = squares.size(0) num_gts = gt_bboxes.size(0) + + if num_squares == 0 or num_gts == 0: + # No predictions and/or truth, return empty assignment + overlaps = approxs.new(num_gts, num_squares) + assign_result = self.assign_wrt_overlaps(overlaps, gt_labels) + return assign_result + # re-organize anchors by approxs_per_octave x num_squares approxs = torch.transpose( approxs.view(num_squares, approxs_per_octave, 4), 0, diff --git a/mmdet/core/bbox/assigners/assign_result.py b/mmdet/core/bbox/assigners/assign_result.py index 33c761d..38a24d7 100644 --- a/mmdet/core/bbox/assigners/assign_result.py +++ b/mmdet/core/bbox/assigners/assign_result.py @@ -2,6 +2,41 @@ import torch class AssignResult(object): + """ + Stores assignments between predicted and truth boxes. 
+ + Attributes: + num_gts (int): the number of truth boxes considered when computing this + assignment + + gt_inds (LongTensor): for each predicted box indicates the 1-based + index of the assigned truth box. 0 means unassigned and -1 means + ignore. + + max_overlaps (FloatTensor): the iou between the predicted box and its + assigned truth box. + + labels (None | LongTensor): If specified, for each predicted box + indicates the category label of the assigned truth box. + + Example: + >>> # An assign result between 4 predicted boxes and 9 true boxes + >>> # where only two boxes were assigned. + >>> num_gts = 9 + >>> max_overlaps = torch.FloatTensor([0, .5, .9, 0]) + >>> gt_inds = torch.LongTensor([-1, 1, 2, 0]) + >>> labels = torch.LongTensor([0, 3, 4, 0]) + >>> self = AssignResult(num_gts, gt_inds, max_overlaps, labels) + >>> print(str(self)) # xdoctest: +IGNORE_WANT + <AssignResult(num_gts=9, gt_inds.shape=(4,), max_overlaps.shape=(4,), + labels.shape=(4,))> + >>> # Force addition of gt labels (when adding gt as proposals) + >>> new_labels = torch.LongTensor([3, 4, 5]) + >>> self.add_gt_(new_labels) + >>> print(str(self)) # xdoctest: +IGNORE_WANT + <AssignResult(num_gts=9, gt_inds.shape=(7,), max_overlaps.shape=(7,), + labels.shape=(7,))> + """ def __init__(self, num_gts, gt_inds, max_overlaps, labels=None): self.num_gts = num_gts @@ -13,7 +48,45 @@ class AssignResult(object): self_inds = torch.arange( 1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device) self.gt_inds = torch.cat([self_inds, self.gt_inds]) + + # Was this a bug? 
+ # self.max_overlaps = torch.cat( + # [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps]) + # IIUC, It seems like the correct code should be: self.max_overlaps = torch.cat( - [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps]) + [self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps]) + if self.labels is not None: self.labels = torch.cat([gt_labels, self.labels]) + + def __nice__(self): + """ + Create a "nice" summary string describing this assign result + """ + parts = [] + parts.append('num_gts={!r}'.format(self.num_gts)) + if self.gt_inds is None: + parts.append('gt_inds={!r}'.format(self.gt_inds)) + else: + parts.append('gt_inds.shape={!r}'.format( + tuple(self.gt_inds.shape))) + if self.max_overlaps is None: + parts.append('max_overlaps={!r}'.format(self.max_overlaps)) + else: + parts.append('max_overlaps.shape={!r}'.format( + tuple(self.max_overlaps.shape))) + if self.labels is None: + parts.append('labels={!r}'.format(self.labels)) + else: + parts.append('labels.shape={!r}'.format(tuple(self.labels.shape))) + return ', '.join(parts) + + def __repr__(self): + nice = self.__nice__() + classname = self.__class__.__name__ + return '<{}({}) at {}>'.format(classname, nice, hex(id(self))) + + def __str__(self): + classname = self.__class__.__name__ + nice = self.__nice__() + return '<{}({})>'.format(classname, nice) diff --git a/mmdet/core/bbox/assigners/max_iou_assigner.py b/mmdet/core/bbox/assigners/max_iou_assigner.py index 3eb8936..93ffc42 100644 --- a/mmdet/core/bbox/assigners/max_iou_assigner.py +++ b/mmdet/core/bbox/assigners/max_iou_assigner.py @@ -74,9 +74,15 @@ class MaxIoUAssigner(BaseAssigner): Returns: :obj:`AssignResult`: The assign result. 
+ + Example: + >>> self = MaxIoUAssigner(0.5, 0.5) + >>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]]) + >>> gt_bboxes = torch.Tensor([[0, 0, 10, 9]]) + >>> assign_result = self.assign(bboxes, gt_bboxes) + >>> expected_gt_inds = torch.LongTensor([1, 0]) + >>> assert torch.all(assign_result.gt_inds == expected_gt_inds) """ - if bboxes.shape[0] == 0 or gt_bboxes.shape[0] == 0: - raise ValueError('No gt or bboxes') assign_on_cpu = True if (self.gpu_assign_thr > 0) and ( gt_bboxes.shape[0] > self.gpu_assign_thr) else False # compute overlap and assign gt on CPU when number of GT is large @@ -88,6 +94,7 @@ class MaxIoUAssigner(BaseAssigner): gt_bboxes_ignore = gt_bboxes_ignore.cpu() if gt_labels is not None: gt_labels = gt_labels.cpu() + bboxes = bboxes[:, :4] overlaps = bbox_overlaps(gt_bboxes, bboxes) @@ -122,9 +129,6 @@ class MaxIoUAssigner(BaseAssigner): Returns: :obj:`AssignResult`: The assign result. """ - if overlaps.numel() == 0: - raise ValueError('No gt or proposals') - num_gts, num_bboxes = overlaps.size(0), overlaps.size(1) # 1. 
assign -1 by default @@ -132,6 +136,23 @@ class MaxIoUAssigner(BaseAssigner): -1, dtype=torch.long) + if num_gts == 0 or num_bboxes == 0: + # No ground truth or boxes, return empty assignment + max_overlaps = overlaps.new_zeros((num_bboxes, )) + if num_gts == 0: + # No truth, assign everything to background + assigned_gt_inds[:] = 0 + if gt_labels is None: + assigned_labels = None + else: + assigned_labels = overlaps.new_zeros((num_bboxes, ), + dtype=torch.long) + return AssignResult( + num_gts, + assigned_gt_inds, + max_overlaps, + labels=assigned_labels) + # for each anchor, which gt best overlaps with it # for each anchor, the max iou of all gts max_overlaps, argmax_overlaps = overlaps.max(dim=0) diff --git a/mmdet/core/bbox/assigners/point_assigner.py b/mmdet/core/bbox/assigners/point_assigner.py index fe81e7d..263b309 100644 --- a/mmdet/core/bbox/assigners/point_assigner.py +++ b/mmdet/core/bbox/assigners/point_assigner.py @@ -40,19 +40,33 @@ class PointAssigner(BaseAssigner): gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4). gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are labelled as `ignored`, e.g., crowd boxes in COCO. + NOTE: currently unused. gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ). Returns: :obj:`AssignResult`: The assign result. """ - if points.shape[0] == 0 or gt_bboxes.shape[0] == 0: - raise ValueError('No gt or bboxes') + num_points = points.shape[0] + num_gts = gt_bboxes.shape[0] + + if num_gts == 0 or num_points == 0: + # If no truth assign everything to the background + assigned_gt_inds = points.new_full((num_points, ), + 0, + dtype=torch.long) + if gt_labels is None: + assigned_labels = None + else: + assigned_labels = points.new_zeros((num_points, ), + dtype=torch.long) + return AssignResult( + num_gts, assigned_gt_inds, None, labels=assigned_labels) + points_xy = points[:, :2] points_stride = points[:, 2] points_lvl = torch.log2( points_stride).int() # [3...,4...,5...,6...,7...] 
lvl_min, lvl_max = points_lvl.min(), points_lvl.max() - num_gts, num_points = gt_bboxes.shape[0], points.shape[0] # assign gt box gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2 diff --git a/mmdet/core/bbox/geometry.py b/mmdet/core/bbox/geometry.py index 3bc8dae..ff7c5d4 100644 --- a/mmdet/core/bbox/geometry.py +++ b/mmdet/core/bbox/geometry.py @@ -9,14 +9,39 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False): bboxes1 and bboxes2. Args: - bboxes1 (Tensor): shape (m, 4) - bboxes2 (Tensor): shape (n, 4), if is_aligned is ``True``, then m and n - must be equal. + bboxes1 (Tensor): shape (m, 4) in <x1, y1, x2, y2> format. + bboxes2 (Tensor): shape (n, 4) in <x1, y1, x2, y2> format. + If is_aligned is ``True``, then m and n must be equal. mode (str): "iou" (intersection over union) or iof (intersection over foreground). Returns: ious(Tensor): shape (m, n) if is_aligned == False else shape (m, 1) + + Example: + >>> bboxes1 = torch.FloatTensor([ + >>> [0, 0, 10, 10], + >>> [10, 10, 20, 20], + >>> [32, 32, 38, 42], + >>> ]) + >>> bboxes2 = torch.FloatTensor([ + >>> [0, 0, 10, 20], + >>> [0, 10, 10, 19], + >>> [10, 10, 20, 20], + >>> ]) + >>> bbox_overlaps(bboxes1, bboxes2) + tensor([[0.5238, 0.0500, 0.0041], + [0.0323, 0.0452, 1.0000], + [0.0000, 0.0000, 0.0000]]) + + Example: + >>> empty = torch.FloatTensor([]) + >>> nonempty = torch.FloatTensor([ + >>> [0, 0, 10, 9], + >>> ]) + >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1) + >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0) + >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0) """ assert mode in ['iou', 'iof'] diff --git a/mmdet/core/bbox/samplers/base_sampler.py b/mmdet/core/bbox/samplers/base_sampler.py index 12df013..a396a8d 100644 --- a/mmdet/core/bbox/samplers/base_sampler.py +++ b/mmdet/core/bbox/samplers/base_sampler.py @@ -51,7 +51,7 @@ class BaseSampler(metaclass=ABCMeta): bboxes = bboxes[:, :4] gt_flags = bboxes.new_zeros((bboxes.shape[0], 
), dtype=torch.uint8) - if self.add_gt_as_proposals: + if self.add_gt_as_proposals and len(gt_bboxes) > 0: bboxes = torch.cat([gt_bboxes, bboxes], dim=0) assign_result.add_gt_(gt_labels) gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8) diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py index 23c9120..d041532 100644 --- a/mmdet/datasets/coco.py +++ b/mmdet/datasets/coco.py @@ -49,7 +49,7 @@ class CocoDataset(CustomDataset): valid_inds = [] ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values()) for i, img_info in enumerate(self.img_infos): - if self.img_ids[i] not in ids_with_ann: + if self.filter_empty_gt and self.img_ids[i] not in ids_with_ann: continue if min(img_info['width'], img_info['height']) >= min_size: valid_inds.append(i) diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py index 84a0191..d068543 100644 --- a/mmdet/datasets/custom.py +++ b/mmdet/datasets/custom.py @@ -40,13 +40,15 @@ class CustomDataset(Dataset): img_prefix='', seg_prefix=None, proposal_file=None, - test_mode=False): + test_mode=False, + filter_empty_gt=True): self.ann_file = ann_file self.data_root = data_root self.img_prefix = img_prefix self.seg_prefix = seg_prefix self.proposal_file = proposal_file self.test_mode = test_mode + self.filter_empty_gt = filter_empty_gt # join paths if data_root is specified if self.data_root is not None: @@ -66,7 +68,7 @@ class CustomDataset(Dataset): self.proposals = self.load_proposals(self.proposal_file) else: self.proposals = None - # filter images with no annotation during training + # filter images too small if not test_mode: valid_inds = self._filter_imgs() self.img_infos = [self.img_infos[i] for i in valid_inds] diff --git a/mmdet/datasets/pipelines/loading.py b/mmdet/datasets/pipelines/loading.py index 6a9938c..9f3007e 100644 --- a/mmdet/datasets/pipelines/loading.py +++ b/mmdet/datasets/pipelines/loading.py @@ -1,5 +1,4 @@ import os.path as osp -import warnings import mmcv import numpy as np @@ 
-42,28 +41,16 @@ class LoadAnnotations(object): with_label=True, with_mask=False, with_seg=False, - poly2mask=True, - skip_img_without_anno=True): + poly2mask=True): self.with_bbox = with_bbox self.with_label = with_label self.with_mask = with_mask self.with_seg = with_seg self.poly2mask = poly2mask - self.skip_img_without_anno = skip_img_without_anno def _load_bboxes(self, results): ann_info = results['ann_info'] results['gt_bboxes'] = ann_info['bboxes'] - if len(results['gt_bboxes']) == 0 and self.skip_img_without_anno: - if results['img_prefix'] is not None: - file_path = osp.join(results['img_prefix'], - results['img_info']['filename']) - else: - file_path = results['img_info']['filename'] - warnings.warn( - 'Skip the image "{}" that has no valid gt bbox'.format( - file_path)) - return None gt_bboxes_ignore = ann_info.get('bboxes_ignore', None) if gt_bboxes_ignore is not None: diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index 760b3b1..dc38597 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -275,7 +275,10 @@ class Pad(object): mmcv.impad(mask, pad_shape, pad_val=self.pad_val) for mask in results[key] ] - results[key] = np.stack(padded_masks, axis=0) + if padded_masks: + results[key] = np.stack(padded_masks, axis=0) + else: + results[key] = np.empty((0, ) + pad_shape, dtype=np.uint8) def __call__(self, results): self._pad_img(results) diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py index 2e983ff..ced0ad1 100644 --- a/mmdet/models/bbox_heads/bbox_head.py +++ b/mmdet/models/bbox_heads/bbox_head.py @@ -107,26 +107,30 @@ class BBoxHead(nn.Module): losses = dict() if cls_score is not None: avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.) 
- losses['loss_cls'] = self.loss_cls( - cls_score, - labels, - label_weights, - avg_factor=avg_factor, - reduction_override=reduction_override) - losses['acc'] = accuracy(cls_score, labels) + if cls_score.numel() > 0: + losses['loss_cls'] = self.loss_cls( + cls_score, + labels, + label_weights, + avg_factor=avg_factor, + reduction_override=reduction_override) + losses['acc'] = accuracy(cls_score, labels) if bbox_pred is not None: pos_inds = labels > 0 - if self.reg_class_agnostic: - pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_inds] - else: - pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, - 4)[pos_inds, labels[pos_inds]] - losses['loss_bbox'] = self.loss_bbox( - pos_bbox_pred, - bbox_targets[pos_inds], - bbox_weights[pos_inds], - avg_factor=bbox_targets.size(0), - reduction_override=reduction_override) + if pos_inds.any(): + if self.reg_class_agnostic: + pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), + 4)[pos_inds] + else: + pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, + 4)[pos_inds, + labels[pos_inds]] + losses['loss_bbox'] = self.loss_bbox( + pos_bbox_pred, + bbox_targets[pos_inds], + bbox_weights[pos_inds], + avg_factor=bbox_targets.size(0), + reduction_override=reduction_override) return losses @force_fp32(apply_to=('cls_score', 'bbox_pred')) diff --git a/mmdet/models/bbox_heads/convfc_bbox_head.py b/mmdet/models/bbox_heads/convfc_bbox_head.py index 777c455..f0f8977 100644 --- a/mmdet/models/bbox_heads/convfc_bbox_head.py +++ b/mmdet/models/bbox_heads/convfc_bbox_head.py @@ -138,7 +138,9 @@ class ConvFCBBoxHead(BBoxHead): if self.num_shared_fcs > 0: if self.with_avg_pool: x = self.avg_pool(x) - x = x.view(x.size(0), -1) + + x = x.flatten(1) + for fc in self.shared_fcs: x = self.relu(fc(x)) # separate branches @@ -150,7 +152,7 @@ class ConvFCBBoxHead(BBoxHead): if x_cls.dim() > 2: if self.with_avg_pool: x_cls = self.avg_pool(x_cls) - x_cls = x_cls.view(x_cls.size(0), -1) + x_cls = x_cls.flatten(1) for fc in self.cls_fcs: x_cls = 
self.relu(fc(x_cls)) @@ -159,7 +161,7 @@ class ConvFCBBoxHead(BBoxHead): if x_reg.dim() > 2: if self.with_avg_pool: x_reg = self.avg_pool(x_reg) - x_reg = x_reg.view(x_reg.size(0), -1) + x_reg = x_reg.flatten(1) for fc in self.reg_fcs: x_reg = self.relu(fc(x_reg)) diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py index e79b189..4ab1e57 100644 --- a/mmdet/models/detectors/cascade_rcnn.py +++ b/mmdet/models/detectors/cascade_rcnn.py @@ -236,6 +236,12 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): bbox_head = self.bbox_head[i] rois = bbox2roi([res.bboxes for res in sampling_results]) + + if len(rois) == 0: + # If there are no predicted and/or truth boxes, then we cannot + # compute head / mask losses + continue + bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs], rois) if self.with_shared_head: diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py index 1558195..962e0cb 100644 --- a/mmdet/models/detectors/two_stage.py +++ b/mmdet/models/detectors/two_stage.py @@ -247,16 +247,16 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, dtype=torch.uint8)) pos_inds = torch.cat(pos_inds) mask_feats = bbox_feats[pos_inds] - mask_pred = self.mask_head(mask_feats) - mask_targets = self.mask_head.get_target(sampling_results, - gt_masks, - self.train_cfg.rcnn) - pos_labels = torch.cat( - [res.pos_gt_labels for res in sampling_results]) - loss_mask = self.mask_head.loss(mask_pred, mask_targets, - pos_labels) - losses.update(loss_mask) + if mask_feats.shape[0] > 0: + mask_pred = self.mask_head(mask_feats) + mask_targets = self.mask_head.get_target( + sampling_results, gt_masks, self.train_cfg.rcnn) + pos_labels = torch.cat( + [res.pos_gt_labels for res in sampling_results]) + loss_mask = self.mask_head.loss(mask_pred, mask_targets, + pos_labels) + losses.update(loss_mask) return losses diff --git a/tests/test_assigner.py b/tests/test_assigner.py new file mode 100644 index 
0000000..50cf7d5 --- /dev/null +++ b/tests/test_assigner.py @@ -0,0 +1,261 @@ +""" +Tests the Assigner objects. + +CommandLine: + pytest tests/test_assigner.py + xdoctest tests/test_assigner.py zero + + + +""" +import torch + +from mmdet.core import MaxIoUAssigner +from mmdet.core.bbox.assigners import ApproxMaxIoUAssigner, PointAssigner + + +def test_max_iou_assigner(): + self = MaxIoUAssigner( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + ) + bboxes = torch.FloatTensor([ + [0, 0, 10, 10], + [10, 10, 20, 20], + [5, 5, 15, 15], + [32, 32, 38, 42], + ]) + gt_bboxes = torch.FloatTensor([ + [0, 0, 10, 9], + [0, 10, 10, 19], + ]) + gt_labels = torch.LongTensor([2, 3]) + assign_result = self.assign(bboxes, gt_bboxes, gt_labels=gt_labels) + assert len(assign_result.gt_inds) == 4 + assert len(assign_result.labels) == 4 + + expected_gt_inds = torch.LongTensor([1, 0, 2, 0]) + assert torch.all(assign_result.gt_inds == expected_gt_inds) + + +def test_max_iou_assigner_with_ignore(): + self = MaxIoUAssigner( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + ignore_iof_thr=0.5, + ignore_wrt_candidates=False, + ) + bboxes = torch.FloatTensor([ + [0, 0, 10, 10], + [10, 10, 20, 20], + [5, 5, 15, 15], + [32, 32, 38, 42], + ]) + gt_bboxes = torch.FloatTensor([ + [0, 0, 10, 9], + [0, 10, 10, 19], + ]) + gt_bboxes_ignore = torch.Tensor([ + [30, 30, 40, 40], + ]) + assign_result = self.assign( + bboxes, gt_bboxes, gt_bboxes_ignore=gt_bboxes_ignore) + + expected_gt_inds = torch.LongTensor([1, 0, 2, -1]) + assert torch.all(assign_result.gt_inds == expected_gt_inds) + + +def test_max_iou_assigner_with_empty_gt(): + """ + Test corner case where an image might have no true detections + """ + self = MaxIoUAssigner( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + ) + bboxes = torch.FloatTensor([ + [0, 0, 10, 10], + [10, 10, 20, 20], + [5, 5, 15, 15], + [32, 32, 38, 42], + ]) + gt_bboxes = torch.FloatTensor([]) + assign_result = self.assign(bboxes, gt_bboxes) + + expected_gt_inds = torch.LongTensor([0, 0, 0, 0]) + 
assert torch.all(assign_result.gt_inds == expected_gt_inds) + + +def test_max_iou_assigner_with_empty_boxes(): + """ + Test corner case where a network might predict no boxes + """ + self = MaxIoUAssigner( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + ) + bboxes = torch.empty((0, 4)) + gt_bboxes = torch.FloatTensor([ + [0, 0, 10, 9], + [0, 10, 10, 19], + ]) + gt_labels = torch.LongTensor([2, 3]) + + # Test with gt_labels + assign_result = self.assign(bboxes, gt_bboxes, gt_labels=gt_labels) + assert len(assign_result.gt_inds) == 0 + assert tuple(assign_result.labels.shape) == (0, ) + + # Test without gt_labels + assign_result = self.assign(bboxes, gt_bboxes, gt_labels=None) + assert len(assign_result.gt_inds) == 0 + assert assign_result.labels is None + + +def test_max_iou_assigner_with_empty_boxes_and_gt(): + """ + Test corner case where a network might predict no boxes and no gt + """ + self = MaxIoUAssigner( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + ) + bboxes = torch.empty((0, 4)) + gt_bboxes = torch.empty((0, 4)) + assign_result = self.assign(bboxes, gt_bboxes) + assert len(assign_result.gt_inds) == 0 + + +def test_point_assigner(): + self = PointAssigner() + points = torch.FloatTensor([ # [x, y, stride] + [0, 0, 1], + [10, 10, 1], + [5, 5, 1], + [32, 32, 1], + ]) + gt_bboxes = torch.FloatTensor([ + [0, 0, 10, 9], + [0, 10, 10, 19], + ]) + assign_result = self.assign(points, gt_bboxes) + expected_gt_inds = torch.LongTensor([1, 2, 1, 0]) + assert torch.all(assign_result.gt_inds == expected_gt_inds) + + +def test_point_assigner_with_empty_gt(): + """ + Test corner case where an image might have no true detections + """ + self = PointAssigner() + points = torch.FloatTensor([ # [x, y, stride] + [0, 0, 1], + [10, 10, 1], + [5, 5, 1], + [32, 32, 1], + ]) + gt_bboxes = torch.FloatTensor([]) + assign_result = self.assign(points, gt_bboxes) + + expected_gt_inds = torch.LongTensor([0, 0, 0, 0]) + assert torch.all(assign_result.gt_inds == expected_gt_inds) + + +def 
test_point_assigner_with_empty_boxes_and_gt(): + """ + Test corner case where a network might predict no points and no gt + """ + self = PointAssigner() + points = torch.FloatTensor([]) + gt_bboxes = torch.FloatTensor([]) + assign_result = self.assign(points, gt_bboxes) + assert len(assign_result.gt_inds) == 0 + + +def test_approx_iou_assigner(): + self = ApproxMaxIoUAssigner( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + ) + bboxes = torch.FloatTensor([ + [0, 0, 10, 10], + [10, 10, 20, 20], + [5, 5, 15, 15], + [32, 32, 38, 42], + ]) + gt_bboxes = torch.FloatTensor([ + [0, 0, 10, 9], + [0, 10, 10, 19], + ]) + approxs_per_octave = 1 + approxs = bboxes + squares = bboxes + assign_result = self.assign(approxs, squares, approxs_per_octave, + gt_bboxes) + + expected_gt_inds = torch.LongTensor([1, 0, 2, 0]) + assert torch.all(assign_result.gt_inds == expected_gt_inds) + + +def test_approx_iou_assigner_with_empty_gt(): + """ + Test corner case where an image might have no true detections + """ + self = ApproxMaxIoUAssigner( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + ) + bboxes = torch.FloatTensor([ + [0, 0, 10, 10], + [10, 10, 20, 20], + [5, 5, 15, 15], + [32, 32, 38, 42], + ]) + gt_bboxes = torch.FloatTensor([]) + approxs_per_octave = 1 + approxs = bboxes + squares = bboxes + assign_result = self.assign(approxs, squares, approxs_per_octave, + gt_bboxes) + + expected_gt_inds = torch.LongTensor([0, 0, 0, 0]) + assert torch.all(assign_result.gt_inds == expected_gt_inds) + + +def test_approx_iou_assigner_with_empty_boxes(): + """ + Test corner case where a network might predict no boxes + """ + self = ApproxMaxIoUAssigner( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + ) + bboxes = torch.empty((0, 4)) + gt_bboxes = torch.FloatTensor([ + [0, 0, 10, 9], + [0, 10, 10, 19], + ]) + approxs_per_octave = 1 + approxs = bboxes + squares = bboxes + assign_result = self.assign(approxs, squares, approxs_per_octave, + gt_bboxes) + assert len(assign_result.gt_inds) == 0 + + +def 
test_approx_iou_assigner_with_empty_boxes_and_gt(): + """ + Test corner case where a network might predict no boxes and no gt + """ + self = ApproxMaxIoUAssigner( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + ) + bboxes = torch.empty((0, 4)) + gt_bboxes = torch.empty((0, 4)) + approxs_per_octave = 1 + approxs = bboxes + squares = bboxes + assign_result = self.assign(approxs, squares, approxs_per_octave, + gt_bboxes) + assert len(assign_result.gt_inds) == 0 diff --git a/tests/test_forward.py b/tests/test_forward.py index dede4ce..5ba56bf 100644 --- a/tests/test_forward.py +++ b/tests/test_forward.py @@ -168,6 +168,164 @@ def test_retina_ghm_forward(): batch_results.append(result) +def test_cascade_forward(): + try: + from torchvision import _C as C # NOQA + except ImportError: + import pytest + raise pytest.skip('requires torchvision on cpu') + + model, train_cfg, test_cfg = _get_detector_cfg( + 'cascade_rcnn_r50_fpn_1x.py') + model['pretrained'] = None + # torchvision roi align supports CPU + model['bbox_roi_extractor']['roi_layer']['use_torchvision'] = True + + from mmdet.models import build_detector + detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg) + + input_shape = (1, 3, 256, 256) + + # Test forward train with a non-empty truth batch + mm_inputs = _demo_mm_inputs(input_shape, num_items=[10]) + imgs = mm_inputs.pop('imgs') + img_metas = mm_inputs.pop('img_metas') + gt_bboxes = mm_inputs['gt_bboxes'] + gt_labels = mm_inputs['gt_labels'] + losses = detector.forward( + imgs, + img_metas, + gt_bboxes=gt_bboxes, + gt_labels=gt_labels, + return_loss=True) + assert isinstance(losses, dict) + from mmdet.apis.train import parse_losses + total_loss = float(parse_losses(losses)[0].item()) + assert total_loss > 0 + + # Test forward train with an empty truth batch + mm_inputs = _demo_mm_inputs(input_shape, num_items=[0]) + imgs = mm_inputs.pop('imgs') + img_metas = mm_inputs.pop('img_metas') + gt_bboxes = mm_inputs['gt_bboxes'] + gt_labels = 
mm_inputs['gt_labels'] + losses = detector.forward( + imgs, + img_metas, + gt_bboxes=gt_bboxes, + gt_labels=gt_labels, + return_loss=True) + assert isinstance(losses, dict) + from mmdet.apis.train import parse_losses + total_loss = float(parse_losses(losses)[0].item()) + assert total_loss > 0 + + +def test_faster_rcnn_forward(): + try: + from torchvision import _C as C # NOQA + except ImportError: + import pytest + raise pytest.skip('requires torchvision on cpu') + + model, train_cfg, test_cfg = _get_detector_cfg('faster_rcnn_r50_fpn_1x.py') + model['pretrained'] = None + # torchvision roi align supports CPU + model['bbox_roi_extractor']['roi_layer']['use_torchvision'] = True + + from mmdet.models import build_detector + detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg) + + input_shape = (1, 3, 256, 256) + + # Test forward train with a non-empty truth batch + mm_inputs = _demo_mm_inputs(input_shape, num_items=[10]) + imgs = mm_inputs.pop('imgs') + img_metas = mm_inputs.pop('img_metas') + gt_bboxes = mm_inputs['gt_bboxes'] + gt_labels = mm_inputs['gt_labels'] + losses = detector.forward( + imgs, + img_metas, + gt_bboxes=gt_bboxes, + gt_labels=gt_labels, + return_loss=True) + assert isinstance(losses, dict) + from mmdet.apis.train import parse_losses + total_loss = float(parse_losses(losses)[0].item()) + assert total_loss > 0 + + # Test forward train with an empty truth batch + mm_inputs = _demo_mm_inputs(input_shape, num_items=[0]) + imgs = mm_inputs.pop('imgs') + img_metas = mm_inputs.pop('img_metas') + gt_bboxes = mm_inputs['gt_bboxes'] + gt_labels = mm_inputs['gt_labels'] + losses = detector.forward( + imgs, + img_metas, + gt_bboxes=gt_bboxes, + gt_labels=gt_labels, + return_loss=True) + assert isinstance(losses, dict) + from mmdet.apis.train import parse_losses + total_loss = float(parse_losses(losses)[0].item()) + assert total_loss > 0 + + +def test_faster_rcnn_ohem_forward(): + try: + from torchvision import _C as C # NOQA + except 
ImportError: + import pytest + raise pytest.skip('requires torchvision on cpu') + + model, train_cfg, test_cfg = _get_detector_cfg( + 'faster_rcnn_ohem_r50_fpn_1x.py') + model['pretrained'] = None + # torchvision roi align supports CPU + model['bbox_roi_extractor']['roi_layer']['use_torchvision'] = True + + from mmdet.models import build_detector + detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg) + + input_shape = (1, 3, 256, 256) + + # Test forward train with a non-empty truth batch + mm_inputs = _demo_mm_inputs(input_shape, num_items=[10]) + imgs = mm_inputs.pop('imgs') + img_metas = mm_inputs.pop('img_metas') + gt_bboxes = mm_inputs['gt_bboxes'] + gt_labels = mm_inputs['gt_labels'] + losses = detector.forward( + imgs, + img_metas, + gt_bboxes=gt_bboxes, + gt_labels=gt_labels, + return_loss=True) + assert isinstance(losses, dict) + from mmdet.apis.train import parse_losses + total_loss = float(parse_losses(losses)[0].item()) + assert total_loss > 0 + + # Test forward train with an empty truth batch + mm_inputs = _demo_mm_inputs(input_shape, num_items=[0]) + imgs = mm_inputs.pop('imgs') + img_metas = mm_inputs.pop('img_metas') + gt_bboxes = mm_inputs['gt_bboxes'] + gt_labels = mm_inputs['gt_labels'] + losses = detector.forward( + imgs, + img_metas, + gt_bboxes=gt_bboxes, + gt_labels=gt_labels, + return_loss=True) + assert isinstance(losses, dict) + from mmdet.apis.train import parse_losses + total_loss = float(parse_losses(losses)[0].item()) + assert total_loss > 0 + + def _demo_mm_inputs(input_shape=(1, 3, 300, 300), num_items=None, num_classes=10): # yapf: disable """ diff --git a/tests/test_heads.py b/tests/test_heads.py new file mode 100644 index 0000000..5c14314 --- /dev/null +++ b/tests/test_heads.py @@ -0,0 +1,171 @@ +import mmcv +import torch + +from mmdet.core import build_assigner, build_sampler +from mmdet.models.anchor_heads import AnchorHead +from mmdet.models.bbox_heads import BBoxHead + + +def test_anchor_head_loss(): + """ + 
Tests anchor head loss when truth is empty and non-empty + """ + self = AnchorHead(num_classes=4, in_channels=1) + s = 256 + img_metas = [{ + 'img_shape': (s, s, 3), + 'scale_factor': 1, + 'pad_shape': (s, s, 3) + }] + + cfg = mmcv.Config({ + 'assigner': { + 'type': 'MaxIoUAssigner', + 'pos_iou_thr': 0.7, + 'neg_iou_thr': 0.3, + 'min_pos_iou': 0.3, + 'ignore_iof_thr': -1 + }, + 'sampler': { + 'type': 'RandomSampler', + 'num': 256, + 'pos_fraction': 0.5, + 'neg_pos_ub': -1, + 'add_gt_as_proposals': False + }, + 'allowed_border': 0, + 'pos_weight': -1, + 'debug': False + }) + + # Anchor head expects a multiple levels of features per image + feat = [ + torch.rand(1, 1, s // (2**(i + 2)), s // (2**(i + 2))) + for i in range(len(self.anchor_generators)) + ] + cls_scores, bbox_preds = self.forward(feat) + + # Test that empty ground truth encourages the network to predict background + gt_bboxes = [torch.empty((0, 4))] + gt_labels = [torch.LongTensor([])] + + gt_bboxes_ignore = None + empty_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_metas, cfg, gt_bboxes_ignore) + # When there is no truth, the cls loss should be nonzero but there should + # be no box loss. 
+ empty_cls_loss = sum(empty_gt_losses['loss_cls']) + empty_box_loss = sum(empty_gt_losses['loss_bbox']) + assert empty_cls_loss.item() > 0, 'cls loss should be non-zero' + assert empty_box_loss.item() == 0, ( + 'there should be no box loss when there are no true boxes') + + # When truth is non-empty then both cls and box loss should be nonzero for + # random inputs + gt_bboxes = [ + torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]), + ] + gt_labels = [torch.LongTensor([2])] + one_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_metas, cfg, gt_bboxes_ignore) + onegt_cls_loss = sum(one_gt_losses['loss_cls']) + onegt_box_loss = sum(one_gt_losses['loss_bbox']) + assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero' + assert onegt_box_loss.item() > 0, 'box loss should be non-zero' + + +def test_bbox_head_loss(): + """ + Tests bbox head loss when truth is empty and non-empty + """ + self = BBoxHead(in_channels=8, roi_feat_size=3) + + num_imgs = 1 + feat = torch.rand(1, 1, 3, 3) + + # Dummy proposals + proposal_list = [ + torch.Tensor([[23.6667, 23.8757, 228.6326, 153.8874]]), + ] + + target_cfg = mmcv.Config({'pos_weight': 1}) + + def _dummy_bbox_sampling(proposal_list, gt_bboxes, gt_labels): + """ + Create sample results that can be passed to BBoxHead.get_target + """ + assign_config = { + 'type': 'MaxIoUAssigner', + 'pos_iou_thr': 0.5, + 'neg_iou_thr': 0.5, + 'min_pos_iou': 0.5, + 'ignore_iof_thr': -1 + } + sampler_config = { + 'type': 'RandomSampler', + 'num': 512, + 'pos_fraction': 0.25, + 'neg_pos_ub': -1, + 'add_gt_as_proposals': True + } + bbox_assigner = build_assigner(assign_config) + bbox_sampler = build_sampler(sampler_config) + gt_bboxes_ignore = [None for _ in range(num_imgs)] + sampling_results = [] + for i in range(num_imgs): + assign_result = bbox_assigner.assign(proposal_list[i], + gt_bboxes[i], + gt_bboxes_ignore[i], + gt_labels[i]) + sampling_result = bbox_sampler.sample( + assign_result, + proposal_list[i], + 
gt_bboxes[i], + gt_labels[i], + feats=feat) + sampling_results.append(sampling_result) + return sampling_results + + # Test bbox loss when truth is empty + gt_bboxes = [torch.empty((0, 4))] + gt_labels = [torch.LongTensor([])] + + sampling_results = _dummy_bbox_sampling(proposal_list, gt_bboxes, + gt_labels) + + bbox_targets = self.get_target(sampling_results, gt_bboxes, gt_labels, + target_cfg) + labels, label_weights, bbox_targets, bbox_weights = bbox_targets + + # Create dummy features "extracted" for each sampled bbox + num_sampled = sum(len(res.bboxes) for res in sampling_results) + dummy_feats = torch.rand(num_sampled, 8 * 3 * 3) + cls_scores, bbox_preds = self.forward(dummy_feats) + + losses = self.loss(cls_scores, bbox_preds, labels, label_weights, + bbox_targets, bbox_weights) + assert losses.get('loss_cls', 0) > 0, 'cls-loss should be non-zero' + assert losses.get('loss_bbox', 0) == 0, 'empty gt loss should be zero' + + # Test bbox loss when truth is non-empty + gt_bboxes = [ + torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]), + ] + gt_labels = [torch.LongTensor([2])] + + sampling_results = _dummy_bbox_sampling(proposal_list, gt_bboxes, + gt_labels) + + bbox_targets = self.get_target(sampling_results, gt_bboxes, gt_labels, + target_cfg) + labels, label_weights, bbox_targets, bbox_weights = bbox_targets + + # Create dummy features "extracted" for each sampled bbox + num_sampled = sum(len(res.bboxes) for res in sampling_results) + dummy_feats = torch.rand(num_sampled, 8 * 3 * 3) + cls_scores, bbox_preds = self.forward(dummy_feats) + + losses = self.loss(cls_scores, bbox_preds, labels, label_weights, + bbox_targets, bbox_weights) + assert losses.get('loss_cls', 0) > 0, 'cls-loss should be non-zero' + assert losses.get('loss_bbox', 0) > 0, 'box-loss should be non-zero' diff --git a/tests/test_sampler.py b/tests/test_sampler.py new file mode 100644 index 0000000..c375d6e --- /dev/null +++ b/tests/test_sampler.py @@ -0,0 +1,235 @@ +import torch + +from 
mmdet.core import MaxIoUAssigner +from mmdet.core.bbox.samplers import OHEMSampler, RandomSampler + + +def test_random_sampler(): + assigner = MaxIoUAssigner( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + ignore_iof_thr=0.5, + ignore_wrt_candidates=False, + ) + bboxes = torch.FloatTensor([ + [0, 0, 10, 10], + [10, 10, 20, 20], + [5, 5, 15, 15], + [32, 32, 38, 42], + ]) + gt_bboxes = torch.FloatTensor([ + [0, 0, 10, 9], + [0, 10, 10, 19], + ]) + gt_labels = torch.LongTensor([1, 2]) + gt_bboxes_ignore = torch.Tensor([ + [30, 30, 40, 40], + ]) + assign_result = assigner.assign( + bboxes, + gt_bboxes, + gt_bboxes_ignore=gt_bboxes_ignore, + gt_labels=gt_labels) + + sampler = RandomSampler( + num=10, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=True) + + sample_result = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels) + + assert len(sample_result.pos_bboxes) == len(sample_result.pos_inds) + assert len(sample_result.neg_bboxes) == len(sample_result.neg_inds) + + +def test_random_sampler_empty_gt(): + assigner = MaxIoUAssigner( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + ignore_iof_thr=0.5, + ignore_wrt_candidates=False, + ) + bboxes = torch.FloatTensor([ + [0, 0, 10, 10], + [10, 10, 20, 20], + [5, 5, 15, 15], + [32, 32, 38, 42], + ]) + gt_bboxes = torch.empty(0, 4) + gt_labels = torch.empty(0, ).long() + assign_result = assigner.assign(bboxes, gt_bboxes, gt_labels=gt_labels) + + sampler = RandomSampler( + num=10, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=True) + + sample_result = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels) + + assert len(sample_result.pos_bboxes) == len(sample_result.pos_inds) + assert len(sample_result.neg_bboxes) == len(sample_result.neg_inds) + + +def test_random_sampler_empty_pred(): + assigner = MaxIoUAssigner( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + ignore_iof_thr=0.5, + ignore_wrt_candidates=False, + ) + bboxes = torch.empty(0, 4) + gt_bboxes = torch.FloatTensor([ + [0, 0, 10, 9], + [0, 10, 10, 19], + ]) + 
gt_labels = torch.LongTensor([1, 2]) + assign_result = assigner.assign(bboxes, gt_bboxes, gt_labels=gt_labels) + + sampler = RandomSampler( + num=10, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=True) + + sample_result = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels) + + assert len(sample_result.pos_bboxes) == len(sample_result.pos_inds) + assert len(sample_result.neg_bboxes) == len(sample_result.neg_inds) + + +def _context_for_ohem(): + try: + from test_forward import _get_detector_cfg + except ImportError: + # Hack: grab testing utils from test_forward to make a context for ohem + import sys + from os.path import dirname + sys.path.insert(0, dirname(__file__)) + from test_forward import _get_detector_cfg + model, train_cfg, test_cfg = _get_detector_cfg( + 'faster_rcnn_ohem_r50_fpn_1x.py') + model['pretrained'] = None + # torchvision roi align supports CPU + model['bbox_roi_extractor']['roi_layer']['use_torchvision'] = True + from mmdet.models import build_detector + context = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg) + return context + + +def test_ohem_sampler(): + + assigner = MaxIoUAssigner( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + ignore_iof_thr=0.5, + ignore_wrt_candidates=False, + ) + bboxes = torch.FloatTensor([ + [0, 0, 10, 10], + [10, 10, 20, 20], + [5, 5, 15, 15], + [32, 32, 38, 42], + ]) + gt_bboxes = torch.FloatTensor([ + [0, 0, 10, 9], + [0, 10, 10, 19], + ]) + gt_labels = torch.LongTensor([1, 2]) + gt_bboxes_ignore = torch.Tensor([ + [30, 30, 40, 40], + ]) + assign_result = assigner.assign( + bboxes, + gt_bboxes, + gt_bboxes_ignore=gt_bboxes_ignore, + gt_labels=gt_labels) + + context = _context_for_ohem() + + sampler = OHEMSampler( + num=10, + pos_fraction=0.5, + context=context, + neg_pos_ub=-1, + add_gt_as_proposals=True) + + feats = [torch.rand(1, 256, int(2**i), int(2**i)) for i in [6, 5, 4, 3, 2]] + sample_result = sampler.sample( + assign_result, bboxes, gt_bboxes, gt_labels, feats=feats) + + assert 
len(sample_result.pos_bboxes) == len(sample_result.pos_inds) + assert len(sample_result.neg_bboxes) == len(sample_result.neg_inds) + + +def test_ohem_sampler_empty_gt(): + + assigner = MaxIoUAssigner( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + ignore_iof_thr=0.5, + ignore_wrt_candidates=False, + ) + bboxes = torch.FloatTensor([ + [0, 0, 10, 10], + [10, 10, 20, 20], + [5, 5, 15, 15], + [32, 32, 38, 42], + ]) + gt_bboxes = torch.empty(0, 4) + gt_labels = torch.LongTensor([]) + gt_bboxes_ignore = torch.Tensor([]) + assign_result = assigner.assign( + bboxes, + gt_bboxes, + gt_bboxes_ignore=gt_bboxes_ignore, + gt_labels=gt_labels) + + context = _context_for_ohem() + + sampler = OHEMSampler( + num=10, + pos_fraction=0.5, + context=context, + neg_pos_ub=-1, + add_gt_as_proposals=True) + + feats = [torch.rand(1, 256, int(2**i), int(2**i)) for i in [6, 5, 4, 3, 2]] + + sample_result = sampler.sample( + assign_result, bboxes, gt_bboxes, gt_labels, feats=feats) + + assert len(sample_result.pos_bboxes) == len(sample_result.pos_inds) + assert len(sample_result.neg_bboxes) == len(sample_result.neg_inds) + + +def test_ohem_sampler_empty_pred(): + assigner = MaxIoUAssigner( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + ignore_iof_thr=0.5, + ignore_wrt_candidates=False, + ) + bboxes = torch.empty(0, 4) + gt_bboxes = torch.FloatTensor([ + [0, 0, 10, 10], + [10, 10, 20, 20], + [5, 5, 15, 15], + [32, 32, 38, 42], + ]) + gt_labels = torch.LongTensor([1, 2, 2, 3]) + gt_bboxes_ignore = torch.Tensor([]) + assign_result = assigner.assign( + bboxes, + gt_bboxes, + gt_bboxes_ignore=gt_bboxes_ignore, + gt_labels=gt_labels) + + context = _context_for_ohem() + + sampler = OHEMSampler( + num=10, + pos_fraction=0.5, + context=context, + neg_pos_ub=-1, + add_gt_as_proposals=True) + + feats = [torch.rand(1, 256, int(2**i), int(2**i)) for i in [6, 5, 4, 3, 2]] + + sample_result = sampler.sample( + assign_result, bboxes, gt_bboxes, gt_labels, feats=feats) + + assert len(sample_result.pos_bboxes) == 
len(sample_result.pos_inds) + assert len(sample_result.neg_bboxes) == len(sample_result.neg_inds) -- GitLab