From 8c0ecd1ea92c83f2ee058fd8eeabaa750c564a9a Mon Sep 17 00:00:00 2001 From: Jon Crall <erotemic@gmail.com> Date: Sun, 12 Jan 2020 23:53:45 -0500 Subject: [PATCH] Fix issue in refine_bboxes and add doctest (#1962) * Fix issue in refine_bboxes and add doctest * fix pillow version on travis * Fixes based on review * Fix errors in doctest and add comprehensive unit test * Fix linting error --- .travis.yml | 1 + mmdet/core/bbox/demodata.py | 65 +++++++++++ mmdet/models/bbox_heads/bbox_head.py | 45 ++++++- tests/requirements.txt | 3 + tests/test_heads.py | 169 +++++++++++++++++++++++++++ 5 files changed, 279 insertions(+), 4 deletions(-) create mode 100644 mmdet/core/bbox/demodata.py diff --git a/.travis.yml b/.travis.yml index d51fc0d..05f7fdb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -25,6 +25,7 @@ before_install: install: - pip install Cython torch==1.2 + - pip install Pillow==6.2.2 - pip install -r requirements.txt - pip install -r tests/requirements.txt diff --git a/mmdet/core/bbox/demodata.py b/mmdet/core/bbox/demodata.py new file mode 100644 index 0000000..d59d654 --- /dev/null +++ b/mmdet/core/bbox/demodata.py @@ -0,0 +1,65 @@ +import numpy as np +import torch + + +def ensure_rng(rng=None): + """ + Simple version of the ``kwarray.ensure_rng`` + + Args: + rng (int | numpy.random.RandomState | None): + if None, then defaults to the global rng. Otherwise this can be an + integer or a RandomState class + Returns: + (numpy.random.RandomState) : rng - + a numpy random number generator + + References: + https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 + """ + + if rng is None: + rng = np.random.mtrand._rand + elif isinstance(rng, int): + rng = np.random.RandomState(rng) + else: + rng = rng + return rng + + +def random_boxes(num=1, scale=1, rng=None): + """ + Simple version of ``kwimage.Boxes.random`` + + Returns: + Tensor: shape (n, 4) in x1, y1, x2, y2 format. + + References: + https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390 + + Example: + >>> num = 3 + >>> scale = 512 + >>> rng = 0 + >>> boxes = random_boxes(num, scale, rng) + >>> print(boxes) + tensor([[280.9925, 278.9802, 308.6148, 366.1769], + [216.9113, 330.6978, 224.0446, 456.5878], + [405.3632, 196.3221, 493.3953, 270.7942]]) + """ + rng = ensure_rng(rng) + + tlbr = rng.rand(num, 4).astype(np.float32) + + tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2]) + tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3]) + br_x = np.maximum(tlbr[:, 0], tlbr[:, 2]) + br_y = np.maximum(tlbr[:, 1], tlbr[:, 3]) + + tlbr[:, 0] = tl_x * scale + tlbr[:, 1] = tl_y * scale + tlbr[:, 2] = br_x * scale + tlbr[:, 3] = br_y * scale + + boxes = torch.from_numpy(tlbr) + return boxes diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py index ced0ad1..8ab878a 100644 --- a/mmdet/models/bbox_heads/bbox_head.py +++ b/mmdet/models/bbox_heads/bbox_head.py @@ -178,7 +178,8 @@ class BBoxHead(nn.Module): Args: rois (Tensor): Shape (n*bs, 5), where n is image number per GPU, - and bs is the sampled RoIs per image. + and bs is the sampled RoIs per image. The first column is + the image id and the next 4 columns are x1, y1, x2, y2. labels (Tensor): Shape (n*bs, ). bbox_preds (Tensor): Shape (n*bs, 4) or (n*bs, 4*#class). pos_is_gts (list[Tensor]): Flags indicating if each positive bbox @@ -187,13 +188,48 @@ class BBoxHead(nn.Module): Returns: list[Tensor]: Refined bboxes of each image in a mini-batch. + + Example: + >>> # xdoctest: +REQUIRES(module:kwarray) + >>> import kwarray + >>> import numpy as np + >>> from mmdet.core.bbox.demodata import random_boxes + >>> self = BBoxHead(reg_class_agnostic=True) + >>> n_roi = 2 + >>> n_img = 4 + >>> scale = 512 + >>> rng = np.random.RandomState(0) + >>> img_metas = [{'img_shape': (scale, scale)} + ... for _ in range(n_img)] + >>> # Create rois in the expected format + >>> roi_boxes = random_boxes(n_roi, scale=scale, rng=rng) + >>> img_ids = torch.randint(0, n_img, (n_roi,)) + >>> img_ids = img_ids.float() + >>> rois = torch.cat([img_ids[:, None], roi_boxes], dim=1) + >>> # Create other args + >>> labels = torch.randint(0, 2, (n_roi,)).long() + >>> bbox_preds = random_boxes(n_roi, scale=scale, rng=rng) + >>> # For each image, pretend random positive boxes are gts + >>> is_label_pos = (labels.numpy() > 0).astype(np.int) + >>> lbl_per_img = kwarray.group_items(is_label_pos, + ... img_ids.numpy()) + >>> pos_per_img = [sum(lbl_per_img.get(gid, [])) + ... for gid in range(n_img)] + >>> pos_is_gts = [ + >>> torch.randint(0, 2, (npos,)).byte().sort( + >>> descending=True)[0] + >>> for npos in pos_per_img + >>> ] + >>> bboxes_list = self.refine_bboxes(rois, labels, bbox_preds, + >>> pos_is_gts, img_metas) + >>> print(bboxes_list) """ img_ids = rois[:, 0].long().unique(sorted=True) - assert img_ids.numel() == len(img_metas) + assert img_ids.numel() <= len(img_metas) bboxes_list = [] for i in range(len(img_metas)): - inds = torch.nonzero(rois[:, 0] == i).squeeze() + inds = torch.nonzero(rois[:, 0] == i).squeeze(dim=1) num_rois = inds.numel() bboxes_ = rois[inds, 1:] @@ -204,6 +240,7 @@ class BBoxHead(nn.Module): bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_, img_meta_) + # filter gt bboxes pos_keep = 1 - pos_is_gts_ keep_inds = pos_is_gts_.new_ones(num_rois) @@ -226,7 +263,7 @@ class BBoxHead(nn.Module): Returns: Tensor: Regressed bboxes, the same shape as input rois. """ - assert rois.size(1) == 4 or rois.size(1) == 5 + assert rois.size(1) == 4 or rois.size(1) == 5, repr(rois.shape) if not self.reg_class_agnostic: label = label * 4 diff --git a/tests/requirements.txt b/tests/requirements.txt index ff60968..6f8c22d 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -5,3 +5,6 @@ pytest-cov codecov xdoctest >= 0.10.0 asynctest + +# Note: used for kwarray.group_items, this may be ported to mmcv in the future. +kwarray diff --git a/tests/test_heads.py b/tests/test_heads.py index 5c14314..b1e4cee 100644 --- a/tests/test_heads.py +++ b/tests/test_heads.py @@ -169,3 +169,172 @@ def test_bbox_head_loss(): bbox_targets, bbox_weights) assert losses.get('loss_cls', 0) > 0, 'cls-loss should be non-zero' assert losses.get('loss_bbox', 0) > 0, 'box-loss should be non-zero' + + +def test_refine_boxes(): + """ + Mirrors the doctest in + ``mmdet.models.bbox_heads.bbox_head.BBoxHead.refine_boxes`` but checks for + multiple values of n_roi / n_img. + """ + self = BBoxHead(reg_class_agnostic=True) + + test_settings = [ + + # Corner case: less rois than images + { + 'n_roi': 2, + 'n_img': 4, + 'rng': 34285940 + }, + + # Corner case: no images + { + 'n_roi': 0, + 'n_img': 0, + 'rng': 52925222 + }, + + # Corner cases: few images / rois + { + 'n_roi': 1, + 'n_img': 1, + 'rng': 1200281 + }, + { + 'n_roi': 2, + 'n_img': 1, + 'rng': 1200282 + }, + { + 'n_roi': 2, + 'n_img': 2, + 'rng': 1200283 + }, + { + 'n_roi': 1, + 'n_img': 2, + 'rng': 1200284 + }, + + # Corner case: no rois few images + { + 'n_roi': 0, + 'n_img': 1, + 'rng': 23955860 + }, + { + 'n_roi': 0, + 'n_img': 2, + 'rng': 25830516 + }, + + # Corner case: no rois many images + { + 'n_roi': 0, + 'n_img': 10, + 'rng': 671346 + }, + { + 'n_roi': 0, + 'n_img': 20, + 'rng': 699807 + }, + + # Corner case: similar num rois and images + { + 'n_roi': 20, + 'n_img': 20, + 'rng': 1200238 + }, + { + 'n_roi': 10, + 'n_img': 20, + 'rng': 1200238 + }, + { + 'n_roi': 5, + 'n_img': 5, + 'rng': 1200238 + }, + + # ---------------------------------- + # Common case: more rois than images + { + 'n_roi': 100, + 'n_img': 1, + 'rng': 337156 + }, + { + 'n_roi': 150, + 'n_img': 2, + 'rng': 275898 + }, + { + 'n_roi': 500, + 'n_img': 5, + 'rng': 4903221 + }, + ] + + for demokw in test_settings: + try: + n_roi = demokw['n_roi'] + n_img = demokw['n_img'] + rng = demokw['rng'] + + print('Test refine_boxes case: {!r}'.format(demokw)) + tup = _demodata_refine_boxes(n_roi, n_img, rng=rng) + rois, labels, bbox_preds, pos_is_gts, img_metas = tup + bboxes_list = self.refine_bboxes(rois, labels, bbox_preds, + pos_is_gts, img_metas) + assert len(bboxes_list) == n_img + assert sum(map(len, bboxes_list)) <= n_roi + assert all(b.shape[1] == 4 for b in bboxes_list) + except Exception: + print('Test failed with demokw={!r}'.format(demokw)) + raise + + +def _demodata_refine_boxes(n_roi, n_img, rng=0): + """ + Create random test data for the + ``mmdet.models.bbox_heads.bbox_head.BBoxHead.refine_boxes`` method + """ + import numpy as np + from mmdet.core.bbox.demodata import random_boxes + from mmdet.core.bbox.demodata import ensure_rng + try: + import kwarray + except ImportError: + import pytest + pytest.skip('kwarray is required for this test') + scale = 512 + rng = ensure_rng(rng) + img_metas = [{'img_shape': (scale, scale)} for _ in range(n_img)] + # Create rois in the expected format + roi_boxes = random_boxes(n_roi, scale=scale, rng=rng) + if n_img == 0: + assert n_roi == 0, 'cannot have any rois if there are no images' + img_ids = torch.empty((0, ), dtype=torch.long) + roi_boxes = torch.empty((0, 4), dtype=torch.float32) + else: + img_ids = rng.randint(0, n_img, (n_roi, )) + img_ids = torch.from_numpy(img_ids) + rois = torch.cat([img_ids[:, None].float(), roi_boxes], dim=1) + # Create other args + labels = rng.randint(0, 2, (n_roi, )) + labels = torch.from_numpy(labels).long() + bbox_preds = random_boxes(n_roi, scale=scale, rng=rng) + # For each image, pretend random positive boxes are gts + is_label_pos = (labels.numpy() > 0).astype(np.int) + lbl_per_img = kwarray.group_items(is_label_pos, img_ids.numpy()) + pos_per_img = [sum(lbl_per_img.get(gid, [])) for gid in range(n_img)] + # randomly generate with numpy then sort with torch + _pos_is_gts = [ + rng.randint(0, 2, (npos, )).astype(np.uint8) for npos in pos_per_img + ] + pos_is_gts = [ + torch.from_numpy(p).sort(descending=True)[0] for p in _pos_is_gts + ] + return rois, labels, bbox_preds, pos_is_gts, img_metas -- GitLab