From 8c0ecd1ea92c83f2ee058fd8eeabaa750c564a9a Mon Sep 17 00:00:00 2001
From: Jon Crall <erotemic@gmail.com>
Date: Sun, 12 Jan 2020 23:53:45 -0500
Subject: [PATCH] Fix issue in refine_bboxes and add doctest (#1962)

* Fix issue in refine_bboxes and add doctest

* fix pillow version on travis

* Fixes based on review

* Fix errors in doctest and add comprehensive unit test

* Fix linting error
---
 .travis.yml                          |   1 +
 mmdet/core/bbox/demodata.py          |  65 +++++++++++
 mmdet/models/bbox_heads/bbox_head.py |  45 ++++++-
 tests/requirements.txt               |   3 +
 tests/test_heads.py                  | 169 +++++++++++++++++++++++++++
 5 files changed, 279 insertions(+), 4 deletions(-)
 create mode 100644 mmdet/core/bbox/demodata.py

diff --git a/.travis.yml b/.travis.yml
index d51fc0d..05f7fdb 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -25,6 +25,7 @@ before_install:
 
 install:
   - pip install Cython torch==1.2
+  - pip install Pillow==6.2.2
   - pip install -r requirements.txt
   - pip install -r tests/requirements.txt
 
diff --git a/mmdet/core/bbox/demodata.py b/mmdet/core/bbox/demodata.py
new file mode 100644
index 0000000..d59d654
--- /dev/null
+++ b/mmdet/core/bbox/demodata.py
@@ -0,0 +1,65 @@
+import numpy as np
+import torch
+
+
+def ensure_rng(rng=None):
+    """
+    Simple version of ``kwarray.ensure_rng``
+
+    Args:
+        rng (int | numpy.integer | numpy.random.RandomState | None):
+            if None, then defaults to the global rng. An integer seeds a
+            new RandomState. Any other object is returned unchanged.
+
+    Returns:
+        (numpy.random.RandomState) : rng -
+            a numpy random number generator
+
+    References:
+        https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270
+    """
+    if rng is None:
+        # numpy's global RandomState (private handle in the mtrand module)
+        return np.random.mtrand._rand
+    if isinstance(rng, (int, np.integer)):
+        # np.integer covers seeds taken from numpy arrays (e.g. np.int64)
+        return np.random.RandomState(rng)
+    return rng
+
+
+def random_boxes(num=1, scale=1, rng=None):
+    """
+    Generate random boxes (simple version of ``kwimage.Boxes.random``)
+
+    Args:
+        num (int): number of boxes to generate.
+        scale (float): factor multiplied onto the random [0, 1) coordinates.
+        rng (int | numpy.random.RandomState | None): seed or generator,
+            resolved with ``ensure_rng``.
+
+    Returns:
+        Tensor: shape (n, 4) in x1, y1, x2, y2 format.
+
+    References:
+        https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390
+
+    Example:
+        >>> num = 3
+        >>> scale = 512
+        >>> rng = 0
+        >>> boxes = random_boxes(num, scale, rng)
+        >>> print(boxes)
+        tensor([[280.9925, 278.9802, 308.6148, 366.1769],
+                [216.9113, 330.6978, 224.0446, 456.5878],
+                [405.3632, 196.3221, 493.3953, 270.7942]])
+    """
+    rng = ensure_rng(rng)
+    coords = rng.rand(num, 4).astype(np.float32)
+
+    # Order each random corner pair so that x1 <= x2 and y1 <= y2
+    x_a, y_a, x_b, y_b = coords.T
+    ordered = (np.minimum(x_a, x_b), np.minimum(y_a, y_b),
+               np.maximum(x_a, x_b), np.maximum(y_a, y_b))
+    for col, values in enumerate(ordered):
+        coords[:, col] = values * scale
+    return torch.from_numpy(coords)
diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py
index ced0ad1..8ab878a 100644
--- a/mmdet/models/bbox_heads/bbox_head.py
+++ b/mmdet/models/bbox_heads/bbox_head.py
@@ -178,7 +178,8 @@ class BBoxHead(nn.Module):
 
         Args:
             rois (Tensor): Shape (n*bs, 5), where n is image number per GPU,
-                and bs is the sampled RoIs per image.
+                and bs is the sampled RoIs per image. The first column is
+                the image id and the next 4 columns are x1, y1, x2, y2.
             labels (Tensor): Shape (n*bs, ).
             bbox_preds (Tensor): Shape (n*bs, 4) or (n*bs, 4*#class).
             pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
@@ -187,13 +188,48 @@ class BBoxHead(nn.Module):
 
         Returns:
             list[Tensor]: Refined bboxes of each image in a mini-batch.
+
+        Example:
+            >>> # xdoctest: +REQUIRES(module:kwarray)
+            >>> import kwarray
+            >>> import numpy as np
+            >>> from mmdet.core.bbox.demodata import random_boxes
+            >>> self = BBoxHead(reg_class_agnostic=True)
+            >>> n_roi = 2
+            >>> n_img = 4
+            >>> scale = 512
+            >>> rng = np.random.RandomState(0)
+            >>> img_metas = [{'img_shape': (scale, scale)}
+            ...              for _ in range(n_img)]
+            >>> # Create rois in the expected format
+            >>> roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
+            >>> img_ids = torch.randint(0, n_img, (n_roi,))
+            >>> img_ids = img_ids.float()
+            >>> rois = torch.cat([img_ids[:, None], roi_boxes], dim=1)
+            >>> # Create other args
+            >>> labels = torch.randint(0, 2, (n_roi,)).long()
+            >>> bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
+            >>> # For each image, pretend random positive boxes are gts
+            >>> is_label_pos = (labels.numpy() > 0).astype(int)
+            >>> lbl_per_img = kwarray.group_items(is_label_pos,
+            ...                                   img_ids.numpy())
+            >>> pos_per_img = [sum(lbl_per_img.get(gid, []))
+            ...                for gid in range(n_img)]
+            >>> pos_is_gts = [
+            ...     torch.randint(0, 2, (npos,)).byte().sort(
+            ...         descending=True)[0]
+            ...     for npos in pos_per_img
+            ... ]
+            >>> bboxes_list = self.refine_bboxes(rois, labels, bbox_preds,
+            ...                                  pos_is_gts, img_metas)
+            >>> print(bboxes_list)
         """
         img_ids = rois[:, 0].long().unique(sorted=True)
-        assert img_ids.numel() == len(img_metas)
+        assert img_ids.numel() <= len(img_metas)
 
         bboxes_list = []
         for i in range(len(img_metas)):
-            inds = torch.nonzero(rois[:, 0] == i).squeeze()
+            inds = torch.nonzero(rois[:, 0] == i).squeeze(dim=1)
             num_rois = inds.numel()
 
             bboxes_ = rois[inds, 1:]
@@ -204,6 +240,7 @@ class BBoxHead(nn.Module):
 
             bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
                                            img_meta_)
+
             # filter gt bboxes
             pos_keep = 1 - pos_is_gts_
             keep_inds = pos_is_gts_.new_ones(num_rois)
@@ -226,7 +263,7 @@ class BBoxHead(nn.Module):
         Returns:
             Tensor: Regressed bboxes, the same shape as input rois.
         """
-        assert rois.size(1) == 4 or rois.size(1) == 5
+        assert rois.size(1) == 4 or rois.size(1) == 5, repr(rois.shape)
 
         if not self.reg_class_agnostic:
             label = label * 4
diff --git a/tests/requirements.txt b/tests/requirements.txt
index ff60968..6f8c22d 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -5,3 +5,6 @@ pytest-cov
 codecov
 xdoctest >= 0.10.0
 asynctest
+
+# Note: used for kwarray.group_items; this may be ported to mmcv in the future.
+kwarray
diff --git a/tests/test_heads.py b/tests/test_heads.py
index 5c14314..b1e4cee 100644
--- a/tests/test_heads.py
+++ b/tests/test_heads.py
@@ -169,3 +169,172 @@ def test_bbox_head_loss():
                        bbox_targets, bbox_weights)
     assert losses.get('loss_cls', 0) > 0, 'cls-loss should be non-zero'
     assert losses.get('loss_bbox', 0) > 0, 'box-loss should be non-zero'
+
+
+def test_refine_boxes():
+    """
+    Mirrors the doctest in
+    ``mmdet.models.bbox_heads.bbox_head.BBoxHead.refine_bboxes`` but checks
+    for multiple values of n_roi / n_img.
+    """
+    self = BBoxHead(reg_class_agnostic=True)
+
+    test_settings = [
+
+        # Corner case: fewer rois than images
+        {
+            'n_roi': 2,
+            'n_img': 4,
+            'rng': 34285940
+        },
+
+        # Corner case: no images
+        {
+            'n_roi': 0,
+            'n_img': 0,
+            'rng': 52925222
+        },
+
+        # Corner cases: few images / rois
+        {
+            'n_roi': 1,
+            'n_img': 1,
+            'rng': 1200281
+        },
+        {
+            'n_roi': 2,
+            'n_img': 1,
+            'rng': 1200282
+        },
+        {
+            'n_roi': 2,
+            'n_img': 2,
+            'rng': 1200283
+        },
+        {
+            'n_roi': 1,
+            'n_img': 2,
+            'rng': 1200284
+        },
+
+        # Corner case: no rois, few images
+        {
+            'n_roi': 0,
+            'n_img': 1,
+            'rng': 23955860
+        },
+        {
+            'n_roi': 0,
+            'n_img': 2,
+            'rng': 25830516
+        },
+
+        # Corner case: no rois, many images
+        {
+            'n_roi': 0,
+            'n_img': 10,
+            'rng': 671346
+        },
+        {
+            'n_roi': 0,
+            'n_img': 20,
+            'rng': 699807
+        },
+
+        # Corner case: similar num rois and images
+        {
+            'n_roi': 20,
+            'n_img': 20,
+            'rng': 1200238
+        },
+        {
+            'n_roi': 10,
+            'n_img': 20,
+            'rng': 1200238
+        },
+        {
+            'n_roi': 5,
+            'n_img': 5,
+            'rng': 1200238
+        },
+
+        # ----------------------------------
+        # Common case: more rois than images
+        {
+            'n_roi': 100,
+            'n_img': 1,
+            'rng': 337156
+        },
+        {
+            'n_roi': 150,
+            'n_img': 2,
+            'rng': 275898
+        },
+        {
+            'n_roi': 500,
+            'n_img': 5,
+            'rng': 4903221
+        },
+    ]
+
+    for demokw in test_settings:
+        try:
+            n_roi = demokw['n_roi']
+            n_img = demokw['n_img']
+            rng = demokw['rng']
+
+            print('Test refine_boxes case: {!r}'.format(demokw))
+            tup = _demodata_refine_boxes(n_roi, n_img, rng=rng)
+            rois, labels, bbox_preds, pos_is_gts, img_metas = tup
+            bboxes_list = self.refine_bboxes(rois, labels, bbox_preds,
+                                             pos_is_gts, img_metas)
+            assert len(bboxes_list) == n_img
+            assert sum(map(len, bboxes_list)) <= n_roi
+            assert all(b.shape[1] == 4 for b in bboxes_list)
+        except Exception:
+            print('Test failed with demokw={!r}'.format(demokw))
+            raise
+
+
+def _demodata_refine_boxes(n_roi, n_img, rng=0):
+    """
+    Create random test data for the
+    ``mmdet.models.bbox_heads.bbox_head.BBoxHead.refine_bboxes`` method
+
+    Returns:
+        tuple: (rois, labels, bbox_preds, pos_is_gts, img_metas)
+    """
+    import numpy as np
+    from mmdet.core.bbox.demodata import ensure_rng, random_boxes
+    try:
+        import kwarray
+    except ImportError:
+        import pytest
+        pytest.skip('kwarray is required for this test')
+    scale = 512
+    rng = ensure_rng(rng)
+    img_metas = [{'img_shape': (scale, scale)} for _ in range(n_img)]
+    # Create rois in the expected format
+    roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
+    if n_img == 0:
+        assert n_roi == 0, 'cannot have any rois if there are no images'
+        img_ids = torch.empty((0, ), dtype=torch.long)
+        roi_boxes = torch.empty((0, 4), dtype=torch.float32)
+    else:
+        img_ids = torch.from_numpy(rng.randint(0, n_img, (n_roi, )))
+    rois = torch.cat([img_ids[:, None].float(), roi_boxes], dim=1)
+    # Create other args
+    labels = torch.from_numpy(rng.randint(0, 2, (n_roi, ))).long()
+    bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
+    # For each image, pretend random positive boxes are gts
+    is_label_pos = (labels.numpy() > 0).astype(int)
+    lbl_per_img = kwarray.group_items(is_label_pos, img_ids.numpy())
+    pos_per_img = [sum(lbl_per_img.get(gid, [])) for gid in range(n_img)]
+    # randomly generate with numpy then sort with torch
+    _pos_is_gts = [
+        rng.randint(0, 2, (npos, )).astype(np.uint8) for npos in pos_per_img
+    ]
+    pos_is_gts = [
+        torch.from_numpy(p).sort(descending=True)[0] for p in _pos_is_gts
+    ]
+    return rois, labels, bbox_preds, pos_is_gts, img_metas
-- 
GitLab