From b69667001f250a54a37129a000a8d5160e047239 Mon Sep 17 00:00:00 2001
From: Jon Crall <erotemic@gmail.com>
Date: Tue, 24 Dec 2019 04:02:28 -0500
Subject: [PATCH] Allow for images to contain zero true detections (#1531)

* Allow for images to contain zero true detections

* Allow for empty assignment in PointAssigner

* Allow ApproxMaxIouAssigner to return an empty result

* Fix CascadeRNN forward when entire batch has no truth

* Correctly assign boxes to background when there is no truth

* Fix assignment tests

* Make flatten robust

* Fix bbox loss with empty pred/truth

* Fix logic error in BBoxHead.loss

* Add tests for empty truth cases

* tests faster rcnn empty forward

* Skip roipool forward tests if torchvision is not installed

* Add tests for bbox/anchor heads

* Consolidate test_forward and test_forward2

* Fix assign_results.labels = None when gt_labels is given; Add test for this case

* Fix OHEM Sampler with zero truth

* remove xdev

* resolve 3 reviews

* Fix flake8

* refactoring

* fix yaml format

* add filter flag

* minor fix

* delete redundant code in load anno

* fix flake8 errors

* quick fix for empty truth with masks

* fix yapf error

* fix mask padding for empty masks

Co-authored-by: Cao Yuhang <yhcao6@gmail.com>
Co-authored-by: Kai Chen <chenkaidev@gmail.com>
---
 .../bbox/assigners/approx_max_iou_assigner.py |  14 +-
 mmdet/core/bbox/assigners/assign_result.py    |  75 ++++-
 mmdet/core/bbox/assigners/max_iou_assigner.py |  31 ++-
 mmdet/core/bbox/assigners/point_assigner.py   |  20 +-
 mmdet/core/bbox/geometry.py                   |  31 ++-
 mmdet/core/bbox/samplers/base_sampler.py      |   2 +-
 mmdet/datasets/coco.py                        |   2 +-
 mmdet/datasets/custom.py                      |   6 +-
 mmdet/datasets/pipelines/loading.py           |  15 +-
 mmdet/datasets/pipelines/transforms.py        |   5 +-
 mmdet/models/bbox_heads/bbox_head.py          |  40 +--
 mmdet/models/bbox_heads/convfc_bbox_head.py   |   8 +-
 mmdet/models/detectors/cascade_rcnn.py        |   6 +
 mmdet/models/detectors/two_stage.py           |  18 +-
 tests/test_assigner.py                        | 261 ++++++++++++++++++
 tests/test_forward.py                         | 158 +++++++++++
 tests/test_heads.py                           | 171 ++++++++++++
 tests/test_sampler.py                         | 235 ++++++++++++++++
 18 files changed, 1032 insertions(+), 66 deletions(-)
 create mode 100644 tests/test_assigner.py
 create mode 100644 tests/test_heads.py
 create mode 100644 tests/test_sampler.py

diff --git a/mmdet/core/bbox/assigners/approx_max_iou_assigner.py b/mmdet/core/bbox/assigners/approx_max_iou_assigner.py
index a52ed26..e7d3510 100644
--- a/mmdet/core/bbox/assigners/approx_max_iou_assigner.py
+++ b/mmdet/core/bbox/assigners/approx_max_iou_assigner.py
@@ -74,9 +74,9 @@ class ApproxMaxIoUAssigner(MaxIoUAssigner):
 
         Args:
             approxs (Tensor): Bounding boxes to be assigned,
-        shape(approxs_per_octave*n, 4).
+                shape(approxs_per_octave*n, 4).
             squares (Tensor): Base Bounding boxes to be assigned,
-        shape(n, 4).
+                shape(n, 4).
             approxs_per_octave (int): number of approxs per octave
             gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
             gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
@@ -86,11 +86,15 @@ class ApproxMaxIoUAssigner(MaxIoUAssigner):
         Returns:
             :obj:`AssignResult`: The assign result.
         """
-
-        if squares.shape[0] == 0 or gt_bboxes.shape[0] == 0:
-            raise ValueError('No gt or approxs')
         num_squares = squares.size(0)
         num_gts = gt_bboxes.size(0)
+
+        if num_squares == 0 or num_gts == 0:
+            # No predictions and/or truth, return empty assignment
+            overlaps = approxs.new(num_gts, num_squares)
+            assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
+            return assign_result
+
         # re-organize anchors by approxs_per_octave x num_squares
         approxs = torch.transpose(
             approxs.view(num_squares, approxs_per_octave, 4), 0,
diff --git a/mmdet/core/bbox/assigners/assign_result.py b/mmdet/core/bbox/assigners/assign_result.py
index 33c761d..38a24d7 100644
--- a/mmdet/core/bbox/assigners/assign_result.py
+++ b/mmdet/core/bbox/assigners/assign_result.py
@@ -2,6 +2,41 @@ import torch
 
 
 class AssignResult(object):
+    """
+    Stores assignments between predicted and truth boxes.
+
+    Attributes:
+        num_gts (int): the number of truth boxes considered when computing this
+            assignment
+
+        gt_inds (LongTensor): for each predicted box indicates the 1-based
+            index of the assigned truth box. 0 means unassigned and -1 means
+            ignore.
+
+        max_overlaps (FloatTensor): the iou between the predicted box and its
+            assigned truth box.
+
+        labels (None | LongTensor): If specified, for each predicted box
+            indicates the category label of the assigned truth box.
+
+    Example:
+        >>> # An assign result between 4 predicted boxes and 9 true boxes
+        >>> # where only two boxes were assigned.
+        >>> num_gts = 9
+        >>> max_overlaps = torch.LongTensor([0, .5, .9, 0])
+        >>> gt_inds = torch.LongTensor([-1, 1, 2, 0])
+        >>> labels = torch.LongTensor([0, 3, 4, 0])
+        >>> self = AssignResult(num_gts, gt_inds, max_overlaps, labels)
+        >>> print(str(self))  # xdoctest: +IGNORE_WANT
+        <AssignResult(num_gts=9, gt_inds.shape=(4,), max_overlaps.shape=(4,),
+                      labels.shape=(4,))>
+        >>> # Force addition of gt labels (when adding gt as proposals)
+        >>> new_labels = torch.LongTensor([3, 4, 5])
+        >>> self.add_gt_(new_labels)
+        >>> print(str(self))  # xdoctest: +IGNORE_WANT
+        <AssignResult(num_gts=9, gt_inds.shape=(7,), max_overlaps.shape=(7,),
+                      labels.shape=(7,))>
+    """
 
     def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
         self.num_gts = num_gts
@@ -13,7 +48,45 @@ class AssignResult(object):
         self_inds = torch.arange(
             1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
         self.gt_inds = torch.cat([self_inds, self.gt_inds])
+
+        # Was this a bug?
+        # self.max_overlaps = torch.cat(
+        #     [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps])
+        # IIUC, It seems like the correct code should be:
         self.max_overlaps = torch.cat(
-            [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps])
+            [self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
+
         if self.labels is not None:
             self.labels = torch.cat([gt_labels, self.labels])
+
+    def __nice__(self):
+        """
+        Create a "nice" summary string describing this assign result
+        """
+        parts = []
+        parts.append('num_gts={!r}'.format(self.num_gts))
+        if self.gt_inds is None:
+            parts.append('gt_inds={!r}'.format(self.gt_inds))
+        else:
+            parts.append('gt_inds.shape={!r}'.format(
+                tuple(self.gt_inds.shape)))
+        if self.max_overlaps is None:
+            parts.append('max_overlaps={!r}'.format(self.max_overlaps))
+        else:
+            parts.append('max_overlaps.shape={!r}'.format(
+                tuple(self.max_overlaps.shape)))
+        if self.labels is None:
+            parts.append('labels={!r}'.format(self.labels))
+        else:
+            parts.append('labels.shape={!r}'.format(tuple(self.labels.shape)))
+        return ', '.join(parts)
+
+    def __repr__(self):
+        nice = self.__nice__()
+        classname = self.__class__.__name__
+        return '<{}({}) at {}>'.format(classname, nice, hex(id(self)))
+
+    def __str__(self):
+        classname = self.__class__.__name__
+        nice = self.__nice__()
+        return '<{}({})>'.format(classname, nice)
diff --git a/mmdet/core/bbox/assigners/max_iou_assigner.py b/mmdet/core/bbox/assigners/max_iou_assigner.py
index 3eb8936..93ffc42 100644
--- a/mmdet/core/bbox/assigners/max_iou_assigner.py
+++ b/mmdet/core/bbox/assigners/max_iou_assigner.py
@@ -74,9 +74,15 @@ class MaxIoUAssigner(BaseAssigner):
 
         Returns:
             :obj:`AssignResult`: The assign result.
+
+        Example:
+            >>> self = MaxIoUAssigner(0.5, 0.5)
+            >>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
+            >>> gt_bboxes = torch.Tensor([[0, 0, 10, 9]])
+            >>> assign_result = self.assign(bboxes, gt_bboxes)
+            >>> expected_gt_inds = torch.LongTensor([1, 0])
+            >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
         """
-        if bboxes.shape[0] == 0 or gt_bboxes.shape[0] == 0:
-            raise ValueError('No gt or bboxes')
         assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
             gt_bboxes.shape[0] > self.gpu_assign_thr) else False
         # compute overlap and assign gt on CPU when number of GT is large
@@ -88,6 +94,7 @@ class MaxIoUAssigner(BaseAssigner):
                 gt_bboxes_ignore = gt_bboxes_ignore.cpu()
             if gt_labels is not None:
                 gt_labels = gt_labels.cpu()
+
         bboxes = bboxes[:, :4]
         overlaps = bbox_overlaps(gt_bboxes, bboxes)
 
@@ -122,9 +129,6 @@ class MaxIoUAssigner(BaseAssigner):
         Returns:
             :obj:`AssignResult`: The assign result.
         """
-        if overlaps.numel() == 0:
-            raise ValueError('No gt or proposals')
-
         num_gts, num_bboxes = overlaps.size(0), overlaps.size(1)
 
         # 1. assign -1 by default
@@ -132,6 +136,23 @@ class MaxIoUAssigner(BaseAssigner):
                                              -1,
                                              dtype=torch.long)
 
+        if num_gts == 0 or num_bboxes == 0:
+            # No ground truth or boxes, return empty assignment
+            max_overlaps = overlaps.new_zeros((num_bboxes, ))
+            if num_gts == 0:
+                # No truth, assign everything to background
+                assigned_gt_inds[:] = 0
+            if gt_labels is None:
+                assigned_labels = None
+            else:
+                assigned_labels = overlaps.new_zeros((num_bboxes, ),
+                                                     dtype=torch.long)
+            return AssignResult(
+                num_gts,
+                assigned_gt_inds,
+                max_overlaps,
+                labels=assigned_labels)
+
         # for each anchor, which gt best overlaps with it
         # for each anchor, the max iou of all gts
         max_overlaps, argmax_overlaps = overlaps.max(dim=0)
diff --git a/mmdet/core/bbox/assigners/point_assigner.py b/mmdet/core/bbox/assigners/point_assigner.py
index fe81e7d..263b309 100644
--- a/mmdet/core/bbox/assigners/point_assigner.py
+++ b/mmdet/core/bbox/assigners/point_assigner.py
@@ -40,19 +40,33 @@ class PointAssigner(BaseAssigner):
             gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
             gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                 labelled as `ignored`, e.g., crowd boxes in COCO.
+                NOTE: currently unused.
             gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
 
         Returns:
             :obj:`AssignResult`: The assign result.
         """
-        if points.shape[0] == 0 or gt_bboxes.shape[0] == 0:
-            raise ValueError('No gt or bboxes')
+        num_points = points.shape[0]
+        num_gts = gt_bboxes.shape[0]
+
+        if num_gts == 0 or num_points == 0:
+            # If no truth assign everything to the background
+            assigned_gt_inds = points.new_full((num_points, ),
+                                               0,
+                                               dtype=torch.long)
+            if gt_labels is None:
+                assigned_labels = None
+            else:
+                assigned_labels = points.new_zeros((num_points, ),
+                                                   dtype=torch.long)
+            return AssignResult(
+                num_gts, assigned_gt_inds, None, labels=assigned_labels)
+
         points_xy = points[:, :2]
         points_stride = points[:, 2]
         points_lvl = torch.log2(
             points_stride).int()  # [3...,4...,5...,6...,7...]
         lvl_min, lvl_max = points_lvl.min(), points_lvl.max()
-        num_gts, num_points = gt_bboxes.shape[0], points.shape[0]
 
         # assign gt box
         gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2
diff --git a/mmdet/core/bbox/geometry.py b/mmdet/core/bbox/geometry.py
index 3bc8dae..ff7c5d4 100644
--- a/mmdet/core/bbox/geometry.py
+++ b/mmdet/core/bbox/geometry.py
@@ -9,14 +9,39 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False):
     bboxes1 and bboxes2.
 
     Args:
-        bboxes1 (Tensor): shape (m, 4)
-        bboxes2 (Tensor): shape (n, 4), if is_aligned is ``True``, then m and n
-            must be equal.
+        bboxes1 (Tensor): shape (m, 4) in <x1, y1, x2, y2> format.
+        bboxes2 (Tensor): shape (n, 4) in <x1, y1, x2, y2> format.
+            If is_aligned is ``True``, then m and n must be equal.
         mode (str): "iou" (intersection over union) or iof (intersection over
             foreground).
 
     Returns:
         ious(Tensor): shape (m, n) if is_aligned == False else shape (m, 1)
+
+    Example:
+        >>> bboxes1 = torch.FloatTensor([
+        >>>     [0, 0, 10, 10],
+        >>>     [10, 10, 20, 20],
+        >>>     [32, 32, 38, 42],
+        >>> ])
+        >>> bboxes2 = torch.FloatTensor([
+        >>>     [0, 0, 10, 20],
+        >>>     [0, 10, 10, 19],
+        >>>     [10, 10, 20, 20],
+        >>> ])
+        >>> bbox_overlaps(bboxes1, bboxes2)
+        tensor([[0.5238, 0.0500, 0.0041],
+                [0.0323, 0.0452, 1.0000],
+                [0.0000, 0.0000, 0.0000]])
+
+    Example:
+        >>> empty = torch.FloatTensor([])
+        >>> nonempty = torch.FloatTensor([
+        >>>     [0, 0, 10, 9],
+        >>> ])
+        >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
+        >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
+        >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
     """
 
     assert mode in ['iou', 'iof']
diff --git a/mmdet/core/bbox/samplers/base_sampler.py b/mmdet/core/bbox/samplers/base_sampler.py
index 12df013..a396a8d 100644
--- a/mmdet/core/bbox/samplers/base_sampler.py
+++ b/mmdet/core/bbox/samplers/base_sampler.py
@@ -51,7 +51,7 @@ class BaseSampler(metaclass=ABCMeta):
         bboxes = bboxes[:, :4]
 
         gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
-        if self.add_gt_as_proposals:
+        if self.add_gt_as_proposals and len(gt_bboxes) > 0:
             bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
             assign_result.add_gt_(gt_labels)
             gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py
index 23c9120..d041532 100644
--- a/mmdet/datasets/coco.py
+++ b/mmdet/datasets/coco.py
@@ -49,7 +49,7 @@ class CocoDataset(CustomDataset):
         valid_inds = []
         ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
         for i, img_info in enumerate(self.img_infos):
-            if self.img_ids[i] not in ids_with_ann:
+            if self.filter_empty_gt and self.img_ids[i] not in ids_with_ann:
                 continue
             if min(img_info['width'], img_info['height']) >= min_size:
                 valid_inds.append(i)
diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py
index 84a0191..d068543 100644
--- a/mmdet/datasets/custom.py
+++ b/mmdet/datasets/custom.py
@@ -40,13 +40,15 @@ class CustomDataset(Dataset):
                  img_prefix='',
                  seg_prefix=None,
                  proposal_file=None,
-                 test_mode=False):
+                 test_mode=False,
+                 filter_empty_gt=True):
         self.ann_file = ann_file
         self.data_root = data_root
         self.img_prefix = img_prefix
         self.seg_prefix = seg_prefix
         self.proposal_file = proposal_file
         self.test_mode = test_mode
+        self.filter_empty_gt = filter_empty_gt
 
         # join paths if data_root is specified
         if self.data_root is not None:
@@ -66,7 +68,7 @@ class CustomDataset(Dataset):
             self.proposals = self.load_proposals(self.proposal_file)
         else:
             self.proposals = None
-        # filter images with no annotation during training
+        # filter images too small
         if not test_mode:
             valid_inds = self._filter_imgs()
             self.img_infos = [self.img_infos[i] for i in valid_inds]
diff --git a/mmdet/datasets/pipelines/loading.py b/mmdet/datasets/pipelines/loading.py
index 6a9938c..9f3007e 100644
--- a/mmdet/datasets/pipelines/loading.py
+++ b/mmdet/datasets/pipelines/loading.py
@@ -1,5 +1,4 @@
 import os.path as osp
-import warnings
 
 import mmcv
 import numpy as np
@@ -42,28 +41,16 @@ class LoadAnnotations(object):
                  with_label=True,
                  with_mask=False,
                  with_seg=False,
-                 poly2mask=True,
-                 skip_img_without_anno=True):
+                 poly2mask=True):
         self.with_bbox = with_bbox
         self.with_label = with_label
         self.with_mask = with_mask
         self.with_seg = with_seg
         self.poly2mask = poly2mask
-        self.skip_img_without_anno = skip_img_without_anno
 
     def _load_bboxes(self, results):
         ann_info = results['ann_info']
         results['gt_bboxes'] = ann_info['bboxes']
-        if len(results['gt_bboxes']) == 0 and self.skip_img_without_anno:
-            if results['img_prefix'] is not None:
-                file_path = osp.join(results['img_prefix'],
-                                     results['img_info']['filename'])
-            else:
-                file_path = results['img_info']['filename']
-            warnings.warn(
-                'Skip the image "{}" that has no valid gt bbox'.format(
-                    file_path))
-            return None
 
         gt_bboxes_ignore = ann_info.get('bboxes_ignore', None)
         if gt_bboxes_ignore is not None:
diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py
index 760b3b1..dc38597 100644
--- a/mmdet/datasets/pipelines/transforms.py
+++ b/mmdet/datasets/pipelines/transforms.py
@@ -275,7 +275,10 @@ class Pad(object):
                 mmcv.impad(mask, pad_shape, pad_val=self.pad_val)
                 for mask in results[key]
             ]
-            results[key] = np.stack(padded_masks, axis=0)
+            if padded_masks:
+                results[key] = np.stack(padded_masks, axis=0)
+            else:
+                results[key] = np.empty((0, ) + pad_shape, dtype=np.uint8)
 
     def __call__(self, results):
         self._pad_img(results)
diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py
index 2e983ff..ced0ad1 100644
--- a/mmdet/models/bbox_heads/bbox_head.py
+++ b/mmdet/models/bbox_heads/bbox_head.py
@@ -107,26 +107,30 @@ class BBoxHead(nn.Module):
         losses = dict()
         if cls_score is not None:
             avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
-            losses['loss_cls'] = self.loss_cls(
-                cls_score,
-                labels,
-                label_weights,
-                avg_factor=avg_factor,
-                reduction_override=reduction_override)
-            losses['acc'] = accuracy(cls_score, labels)
+            if cls_score.numel() > 0:
+                losses['loss_cls'] = self.loss_cls(
+                    cls_score,
+                    labels,
+                    label_weights,
+                    avg_factor=avg_factor,
+                    reduction_override=reduction_override)
+                losses['acc'] = accuracy(cls_score, labels)
         if bbox_pred is not None:
             pos_inds = labels > 0
-            if self.reg_class_agnostic:
-                pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_inds]
-            else:
-                pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1,
-                                               4)[pos_inds, labels[pos_inds]]
-            losses['loss_bbox'] = self.loss_bbox(
-                pos_bbox_pred,
-                bbox_targets[pos_inds],
-                bbox_weights[pos_inds],
-                avg_factor=bbox_targets.size(0),
-                reduction_override=reduction_override)
+            if pos_inds.any():
+                if self.reg_class_agnostic:
+                    pos_bbox_pred = bbox_pred.view(bbox_pred.size(0),
+                                                   4)[pos_inds]
+                else:
+                    pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1,
+                                                   4)[pos_inds,
+                                                      labels[pos_inds]]
+                losses['loss_bbox'] = self.loss_bbox(
+                    pos_bbox_pred,
+                    bbox_targets[pos_inds],
+                    bbox_weights[pos_inds],
+                    avg_factor=bbox_targets.size(0),
+                    reduction_override=reduction_override)
         return losses
 
     @force_fp32(apply_to=('cls_score', 'bbox_pred'))
diff --git a/mmdet/models/bbox_heads/convfc_bbox_head.py b/mmdet/models/bbox_heads/convfc_bbox_head.py
index 777c455..f0f8977 100644
--- a/mmdet/models/bbox_heads/convfc_bbox_head.py
+++ b/mmdet/models/bbox_heads/convfc_bbox_head.py
@@ -138,7 +138,9 @@ class ConvFCBBoxHead(BBoxHead):
         if self.num_shared_fcs > 0:
             if self.with_avg_pool:
                 x = self.avg_pool(x)
-            x = x.view(x.size(0), -1)
+
+            x = x.flatten(1)
+
             for fc in self.shared_fcs:
                 x = self.relu(fc(x))
         # separate branches
@@ -150,7 +152,7 @@ class ConvFCBBoxHead(BBoxHead):
         if x_cls.dim() > 2:
             if self.with_avg_pool:
                 x_cls = self.avg_pool(x_cls)
-            x_cls = x_cls.view(x_cls.size(0), -1)
+            x_cls = x_cls.flatten(1)
         for fc in self.cls_fcs:
             x_cls = self.relu(fc(x_cls))
 
@@ -159,7 +161,7 @@ class ConvFCBBoxHead(BBoxHead):
         if x_reg.dim() > 2:
             if self.with_avg_pool:
                 x_reg = self.avg_pool(x_reg)
-            x_reg = x_reg.view(x_reg.size(0), -1)
+            x_reg = x_reg.flatten(1)
         for fc in self.reg_fcs:
             x_reg = self.relu(fc(x_reg))
 
diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py
index e79b189..4ab1e57 100644
--- a/mmdet/models/detectors/cascade_rcnn.py
+++ b/mmdet/models/detectors/cascade_rcnn.py
@@ -236,6 +236,12 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
             bbox_head = self.bbox_head[i]
 
             rois = bbox2roi([res.bboxes for res in sampling_results])
+
+            if len(rois) == 0:
+                # If there are no predicted and/or truth boxes, then we cannot
+                # compute head / mask losses
+                continue
+
             bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
                                             rois)
             if self.with_shared_head:
diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py
index 1558195..962e0cb 100644
--- a/mmdet/models/detectors/two_stage.py
+++ b/mmdet/models/detectors/two_stage.py
@@ -247,16 +247,16 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                             dtype=torch.uint8))
                 pos_inds = torch.cat(pos_inds)
                 mask_feats = bbox_feats[pos_inds]
-            mask_pred = self.mask_head(mask_feats)
 
-            mask_targets = self.mask_head.get_target(sampling_results,
-                                                     gt_masks,
-                                                     self.train_cfg.rcnn)
-            pos_labels = torch.cat(
-                [res.pos_gt_labels for res in sampling_results])
-            loss_mask = self.mask_head.loss(mask_pred, mask_targets,
-                                            pos_labels)
-            losses.update(loss_mask)
+            if mask_feats.shape[0] > 0:
+                mask_pred = self.mask_head(mask_feats)
+                mask_targets = self.mask_head.get_target(
+                    sampling_results, gt_masks, self.train_cfg.rcnn)
+                pos_labels = torch.cat(
+                    [res.pos_gt_labels for res in sampling_results])
+                loss_mask = self.mask_head.loss(mask_pred, mask_targets,
+                                                pos_labels)
+                losses.update(loss_mask)
 
         return losses
 
diff --git a/tests/test_assigner.py b/tests/test_assigner.py
new file mode 100644
index 0000000..50cf7d5
--- /dev/null
+++ b/tests/test_assigner.py
@@ -0,0 +1,261 @@
+"""
+Tests the Assigner objects.
+
+CommandLine:
+    pytest tests/test_assigner.py
+    xdoctest tests/test_assigner.py zero
+
+
+
+"""
+import torch
+
+from mmdet.core import MaxIoUAssigner
+from mmdet.core.bbox.assigners import ApproxMaxIoUAssigner, PointAssigner
+
+
+def test_max_iou_assigner():
+    self = MaxIoUAssigner(
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+    )
+    bboxes = torch.FloatTensor([
+        [0, 0, 10, 10],
+        [10, 10, 20, 20],
+        [5, 5, 15, 15],
+        [32, 32, 38, 42],
+    ])
+    gt_bboxes = torch.FloatTensor([
+        [0, 0, 10, 9],
+        [0, 10, 10, 19],
+    ])
+    gt_labels = torch.LongTensor([2, 3])
+    assign_result = self.assign(bboxes, gt_bboxes, gt_labels=gt_labels)
+    assert len(assign_result.gt_inds) == 4
+    assert len(assign_result.labels) == 4
+
+    expected_gt_inds = torch.LongTensor([1, 0, 2, 0])
+    assert torch.all(assign_result.gt_inds == expected_gt_inds)
+
+
+def test_max_iou_assigner_with_ignore():
+    self = MaxIoUAssigner(
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+        ignore_iof_thr=0.5,
+        ignore_wrt_candidates=False,
+    )
+    bboxes = torch.FloatTensor([
+        [0, 0, 10, 10],
+        [10, 10, 20, 20],
+        [5, 5, 15, 15],
+        [32, 32, 38, 42],
+    ])
+    gt_bboxes = torch.FloatTensor([
+        [0, 0, 10, 9],
+        [0, 10, 10, 19],
+    ])
+    gt_bboxes_ignore = torch.Tensor([
+        [30, 30, 40, 40],
+    ])
+    assign_result = self.assign(
+        bboxes, gt_bboxes, gt_bboxes_ignore=gt_bboxes_ignore)
+
+    expected_gt_inds = torch.LongTensor([1, 0, 2, -1])
+    assert torch.all(assign_result.gt_inds == expected_gt_inds)
+
+
+def test_max_iou_assigner_with_empty_gt():
+    """
+    Test corner case where an image might have no true detections
+    """
+    self = MaxIoUAssigner(
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+    )
+    bboxes = torch.FloatTensor([
+        [0, 0, 10, 10],
+        [10, 10, 20, 20],
+        [5, 5, 15, 15],
+        [32, 32, 38, 42],
+    ])
+    gt_bboxes = torch.FloatTensor([])
+    assign_result = self.assign(bboxes, gt_bboxes)
+
+    expected_gt_inds = torch.LongTensor([0, 0, 0, 0])
+    assert torch.all(assign_result.gt_inds == expected_gt_inds)
+
+
+def test_max_iou_assigner_with_empty_boxes():
+    """
+    Test corner case where an network might predict no boxes
+    """
+    self = MaxIoUAssigner(
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+    )
+    bboxes = torch.empty((0, 4))
+    gt_bboxes = torch.FloatTensor([
+        [0, 0, 10, 9],
+        [0, 10, 10, 19],
+    ])
+    gt_labels = torch.LongTensor([2, 3])
+
+    # Test with gt_labels
+    assign_result = self.assign(bboxes, gt_bboxes, gt_labels=gt_labels)
+    assert len(assign_result.gt_inds) == 0
+    assert tuple(assign_result.labels.shape) == (0, )
+
+    # Test without gt_labels
+    assign_result = self.assign(bboxes, gt_bboxes, gt_labels=None)
+    assert len(assign_result.gt_inds) == 0
+    assert assign_result.labels is None
+
+
+def test_max_iou_assigner_with_empty_boxes_and_gt():
+    """
+    Test corner case where an network might predict no boxes and no gt
+    """
+    self = MaxIoUAssigner(
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+    )
+    bboxes = torch.empty((0, 4))
+    gt_bboxes = torch.empty((0, 4))
+    assign_result = self.assign(bboxes, gt_bboxes)
+    assert len(assign_result.gt_inds) == 0
+
+
+def test_point_assigner():
+    self = PointAssigner()
+    points = torch.FloatTensor([  # [x, y, stride]
+        [0, 0, 1],
+        [10, 10, 1],
+        [5, 5, 1],
+        [32, 32, 1],
+    ])
+    gt_bboxes = torch.FloatTensor([
+        [0, 0, 10, 9],
+        [0, 10, 10, 19],
+    ])
+    assign_result = self.assign(points, gt_bboxes)
+    expected_gt_inds = torch.LongTensor([1, 2, 1, 0])
+    assert torch.all(assign_result.gt_inds == expected_gt_inds)
+
+
+def test_point_assigner_with_empty_gt():
+    """
+    Test corner case where an image might have no true detections
+    """
+    self = PointAssigner()
+    points = torch.FloatTensor([  # [x, y, stride]
+        [0, 0, 1],
+        [10, 10, 1],
+        [5, 5, 1],
+        [32, 32, 1],
+    ])
+    gt_bboxes = torch.FloatTensor([])
+    assign_result = self.assign(points, gt_bboxes)
+
+    expected_gt_inds = torch.LongTensor([0, 0, 0, 0])
+    assert torch.all(assign_result.gt_inds == expected_gt_inds)
+
+
+def test_point_assigner_with_empty_boxes_and_gt():
+    """
+    Test corner case where an image might predict no points and no gt
+    """
+    self = PointAssigner()
+    points = torch.FloatTensor([])
+    gt_bboxes = torch.FloatTensor([])
+    assign_result = self.assign(points, gt_bboxes)
+    assert len(assign_result.gt_inds) == 0
+
+
+def test_approx_iou_assigner():
+    self = ApproxMaxIoUAssigner(
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+    )
+    bboxes = torch.FloatTensor([
+        [0, 0, 10, 10],
+        [10, 10, 20, 20],
+        [5, 5, 15, 15],
+        [32, 32, 38, 42],
+    ])
+    gt_bboxes = torch.FloatTensor([
+        [0, 0, 10, 9],
+        [0, 10, 10, 19],
+    ])
+    approxs_per_octave = 1
+    approxs = bboxes
+    squares = bboxes
+    assign_result = self.assign(approxs, squares, approxs_per_octave,
+                                gt_bboxes)
+
+    expected_gt_inds = torch.LongTensor([1, 0, 2, 0])
+    assert torch.all(assign_result.gt_inds == expected_gt_inds)
+
+
+def test_approx_iou_assigner_with_empty_gt():
+    """
+    Test corner case where an image might have no true detections
+    """
+    self = ApproxMaxIoUAssigner(
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+    )
+    bboxes = torch.FloatTensor([
+        [0, 0, 10, 10],
+        [10, 10, 20, 20],
+        [5, 5, 15, 15],
+        [32, 32, 38, 42],
+    ])
+    gt_bboxes = torch.FloatTensor([])
+    approxs_per_octave = 1
+    approxs = bboxes
+    squares = bboxes
+    assign_result = self.assign(approxs, squares, approxs_per_octave,
+                                gt_bboxes)
+
+    expected_gt_inds = torch.LongTensor([0, 0, 0, 0])
+    assert torch.all(assign_result.gt_inds == expected_gt_inds)
+
+
+def test_approx_iou_assigner_with_empty_boxes():
+    """
+    Test corner case where an network might predict no boxes
+    """
+    self = ApproxMaxIoUAssigner(
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+    )
+    bboxes = torch.empty((0, 4))
+    gt_bboxes = torch.FloatTensor([
+        [0, 0, 10, 9],
+        [0, 10, 10, 19],
+    ])
+    approxs_per_octave = 1
+    approxs = bboxes
+    squares = bboxes
+    assign_result = self.assign(approxs, squares, approxs_per_octave,
+                                gt_bboxes)
+    assert len(assign_result.gt_inds) == 0
+
+
+def test_approx_iou_assigner_with_empty_boxes_and_gt():
+    """
+    Test corner case where an network might predict no boxes and no gt
+    """
+    self = ApproxMaxIoUAssigner(
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+    )
+    bboxes = torch.empty((0, 4))
+    gt_bboxes = torch.empty((0, 4))
+    approxs_per_octave = 1
+    approxs = bboxes
+    squares = bboxes
+    assign_result = self.assign(approxs, squares, approxs_per_octave,
+                                gt_bboxes)
+    assert len(assign_result.gt_inds) == 0
diff --git a/tests/test_forward.py b/tests/test_forward.py
index dede4ce..5ba56bf 100644
--- a/tests/test_forward.py
+++ b/tests/test_forward.py
@@ -168,6 +168,164 @@ def test_retina_ghm_forward():
                 batch_results.append(result)
 
 
+def test_cascade_forward():
+    try:
+        from torchvision import _C as C  # NOQA
+    except ImportError:
+        import pytest
+        raise pytest.skip('requires torchvision on cpu')
+
+    model, train_cfg, test_cfg = _get_detector_cfg(
+        'cascade_rcnn_r50_fpn_1x.py')
+    model['pretrained'] = None
+    # torchvision roi align supports CPU
+    model['bbox_roi_extractor']['roi_layer']['use_torchvision'] = True
+
+    from mmdet.models import build_detector
+    detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)
+
+    input_shape = (1, 3, 256, 256)
+
+    # Test forward train with a non-empty truth batch
+    mm_inputs = _demo_mm_inputs(input_shape, num_items=[10])
+    imgs = mm_inputs.pop('imgs')
+    img_metas = mm_inputs.pop('img_metas')
+    gt_bboxes = mm_inputs['gt_bboxes']
+    gt_labels = mm_inputs['gt_labels']
+    losses = detector.forward(
+        imgs,
+        img_metas,
+        gt_bboxes=gt_bboxes,
+        gt_labels=gt_labels,
+        return_loss=True)
+    assert isinstance(losses, dict)
+    from mmdet.apis.train import parse_losses
+    total_loss = float(parse_losses(losses)[0].item())
+    assert total_loss > 0
+
+    # Test forward train with an empty truth batch
+    mm_inputs = _demo_mm_inputs(input_shape, num_items=[0])
+    imgs = mm_inputs.pop('imgs')
+    img_metas = mm_inputs.pop('img_metas')
+    gt_bboxes = mm_inputs['gt_bboxes']
+    gt_labels = mm_inputs['gt_labels']
+    losses = detector.forward(
+        imgs,
+        img_metas,
+        gt_bboxes=gt_bboxes,
+        gt_labels=gt_labels,
+        return_loss=True)
+    assert isinstance(losses, dict)
+    from mmdet.apis.train import parse_losses
+    total_loss = float(parse_losses(losses)[0].item())
+    assert total_loss > 0
+
+
+def test_faster_rcnn_forward():
+    try:
+        from torchvision import _C as C  # NOQA
+    except ImportError:
+        import pytest
+        raise pytest.skip('requires torchvision on cpu')
+
+    model, train_cfg, test_cfg = _get_detector_cfg('faster_rcnn_r50_fpn_1x.py')
+    model['pretrained'] = None
+    # torchvision roi align supports CPU
+    model['bbox_roi_extractor']['roi_layer']['use_torchvision'] = True
+
+    from mmdet.models import build_detector
+    detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)
+
+    input_shape = (1, 3, 256, 256)
+
+    # Test forward train with a non-empty truth batch
+    mm_inputs = _demo_mm_inputs(input_shape, num_items=[10])
+    imgs = mm_inputs.pop('imgs')
+    img_metas = mm_inputs.pop('img_metas')
+    gt_bboxes = mm_inputs['gt_bboxes']
+    gt_labels = mm_inputs['gt_labels']
+    losses = detector.forward(
+        imgs,
+        img_metas,
+        gt_bboxes=gt_bboxes,
+        gt_labels=gt_labels,
+        return_loss=True)
+    assert isinstance(losses, dict)
+    from mmdet.apis.train import parse_losses
+    total_loss = float(parse_losses(losses)[0].item())
+    assert total_loss > 0
+
+    # Test forward train with an empty truth batch
+    mm_inputs = _demo_mm_inputs(input_shape, num_items=[0])
+    imgs = mm_inputs.pop('imgs')
+    img_metas = mm_inputs.pop('img_metas')
+    gt_bboxes = mm_inputs['gt_bboxes']
+    gt_labels = mm_inputs['gt_labels']
+    losses = detector.forward(
+        imgs,
+        img_metas,
+        gt_bboxes=gt_bboxes,
+        gt_labels=gt_labels,
+        return_loss=True)
+    assert isinstance(losses, dict)
+    from mmdet.apis.train import parse_losses
+    total_loss = float(parse_losses(losses)[0].item())
+    assert total_loss > 0
+
+
+def test_faster_rcnn_ohem_forward():
+    try:
+        from torchvision import _C as C  # NOQA
+    except ImportError:
+        import pytest
+        raise pytest.skip('requires torchvision on cpu')
+
+    model, train_cfg, test_cfg = _get_detector_cfg(
+        'faster_rcnn_ohem_r50_fpn_1x.py')
+    model['pretrained'] = None
+    # torchvision roi align supports CPU
+    model['bbox_roi_extractor']['roi_layer']['use_torchvision'] = True
+
+    from mmdet.models import build_detector
+    detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)
+
+    input_shape = (1, 3, 256, 256)
+
+    # Test forward train with a non-empty truth batch
+    mm_inputs = _demo_mm_inputs(input_shape, num_items=[10])
+    imgs = mm_inputs.pop('imgs')
+    img_metas = mm_inputs.pop('img_metas')
+    gt_bboxes = mm_inputs['gt_bboxes']
+    gt_labels = mm_inputs['gt_labels']
+    losses = detector.forward(
+        imgs,
+        img_metas,
+        gt_bboxes=gt_bboxes,
+        gt_labels=gt_labels,
+        return_loss=True)
+    assert isinstance(losses, dict)
+    from mmdet.apis.train import parse_losses
+    total_loss = float(parse_losses(losses)[0].item())
+    assert total_loss > 0
+
+    # Test forward train with an empty truth batch
+    mm_inputs = _demo_mm_inputs(input_shape, num_items=[0])
+    imgs = mm_inputs.pop('imgs')
+    img_metas = mm_inputs.pop('img_metas')
+    gt_bboxes = mm_inputs['gt_bboxes']
+    gt_labels = mm_inputs['gt_labels']
+    losses = detector.forward(
+        imgs,
+        img_metas,
+        gt_bboxes=gt_bboxes,
+        gt_labels=gt_labels,
+        return_loss=True)
+    assert isinstance(losses, dict)
+    from mmdet.apis.train import parse_losses
+    total_loss = float(parse_losses(losses)[0].item())
+    assert total_loss > 0
+
+
 def _demo_mm_inputs(input_shape=(1, 3, 300, 300),
                     num_items=None, num_classes=10):  # yapf: disable
     """
diff --git a/tests/test_heads.py b/tests/test_heads.py
new file mode 100644
index 0000000..5c14314
--- /dev/null
+++ b/tests/test_heads.py
@@ -0,0 +1,171 @@
+import mmcv
+import torch
+
+from mmdet.core import build_assigner, build_sampler
+from mmdet.models.anchor_heads import AnchorHead
+from mmdet.models.bbox_heads import BBoxHead
+
+
+def test_anchor_head_loss():
+    """
+    Tests anchor head loss when truth is empty and non-empty
+    """
+    self = AnchorHead(num_classes=4, in_channels=1)
+    s = 256
+    img_metas = [{
+        'img_shape': (s, s, 3),
+        'scale_factor': 1,
+        'pad_shape': (s, s, 3)
+    }]
+
+    cfg = mmcv.Config({
+        'assigner': {
+            'type': 'MaxIoUAssigner',
+            'pos_iou_thr': 0.7,
+            'neg_iou_thr': 0.3,
+            'min_pos_iou': 0.3,
+            'ignore_iof_thr': -1
+        },
+        'sampler': {
+            'type': 'RandomSampler',
+            'num': 256,
+            'pos_fraction': 0.5,
+            'neg_pos_ub': -1,
+            'add_gt_as_proposals': False
+        },
+        'allowed_border': 0,
+        'pos_weight': -1,
+        'debug': False
+    })
+
+    # Anchor head expects a multiple levels of features per image
+    feat = [
+        torch.rand(1, 1, s // (2**(i + 2)), s // (2**(i + 2)))
+        for i in range(len(self.anchor_generators))
+    ]
+    cls_scores, bbox_preds = self.forward(feat)
+
+    # Test that empty ground truth encourages the network to predict background
+    gt_bboxes = [torch.empty((0, 4))]
+    gt_labels = [torch.LongTensor([])]
+
+    gt_bboxes_ignore = None
+    empty_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels,
+                                img_metas, cfg, gt_bboxes_ignore)
+    # When there is no truth, the cls loss should be nonzero but there should
+    # be no box loss.
+    empty_cls_loss = sum(empty_gt_losses['loss_cls'])
+    empty_box_loss = sum(empty_gt_losses['loss_bbox'])
+    assert empty_cls_loss.item() > 0, 'cls loss should be non-zero'
+    assert empty_box_loss.item() == 0, (
+        'there should be no box loss when there are no true boxes')
+
+    # When truth is non-empty then both cls and box loss should be nonzero for
+    # random inputs
+    gt_bboxes = [
+        torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
+    ]
+    gt_labels = [torch.LongTensor([2])]
+    one_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels,
+                              img_metas, cfg, gt_bboxes_ignore)
+    onegt_cls_loss = sum(one_gt_losses['loss_cls'])
+    onegt_box_loss = sum(one_gt_losses['loss_bbox'])
+    assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero'
+    assert onegt_box_loss.item() > 0, 'box loss should be non-zero'
+
+
+def test_bbox_head_loss():
+    """
+    Tests bbox head loss when truth is empty and non-empty
+    """
+    self = BBoxHead(in_channels=8, roi_feat_size=3)
+
+    num_imgs = 1
+    feat = torch.rand(1, 1, 3, 3)
+
+    # Dummy proposals
+    proposal_list = [
+        torch.Tensor([[23.6667, 23.8757, 228.6326, 153.8874]]),
+    ]
+
+    target_cfg = mmcv.Config({'pos_weight': 1})
+
+    def _dummy_bbox_sampling(proposal_list, gt_bboxes, gt_labels):
+        """
+        Create sample results that can be passed to BBoxHead.get_target
+        """
+        assign_config = {
+            'type': 'MaxIoUAssigner',
+            'pos_iou_thr': 0.5,
+            'neg_iou_thr': 0.5,
+            'min_pos_iou': 0.5,
+            'ignore_iof_thr': -1
+        }
+        sampler_config = {
+            'type': 'RandomSampler',
+            'num': 512,
+            'pos_fraction': 0.25,
+            'neg_pos_ub': -1,
+            'add_gt_as_proposals': True
+        }
+        bbox_assigner = build_assigner(assign_config)
+        bbox_sampler = build_sampler(sampler_config)
+        gt_bboxes_ignore = [None for _ in range(num_imgs)]
+        sampling_results = []
+        for i in range(num_imgs):
+            assign_result = bbox_assigner.assign(proposal_list[i],
+                                                 gt_bboxes[i],
+                                                 gt_bboxes_ignore[i],
+                                                 gt_labels[i])
+            sampling_result = bbox_sampler.sample(
+                assign_result,
+                proposal_list[i],
+                gt_bboxes[i],
+                gt_labels[i],
+                feats=feat)
+            sampling_results.append(sampling_result)
+        return sampling_results
+
+    # Test bbox loss when truth is empty
+    gt_bboxes = [torch.empty((0, 4))]
+    gt_labels = [torch.LongTensor([])]
+
+    sampling_results = _dummy_bbox_sampling(proposal_list, gt_bboxes,
+                                            gt_labels)
+
+    bbox_targets = self.get_target(sampling_results, gt_bboxes, gt_labels,
+                                   target_cfg)
+    labels, label_weights, bbox_targets, bbox_weights = bbox_targets
+
+    # Create dummy features "extracted" for each sampled bbox
+    num_sampled = sum(len(res.bboxes) for res in sampling_results)
+    dummy_feats = torch.rand(num_sampled, 8 * 3 * 3)
+    cls_scores, bbox_preds = self.forward(dummy_feats)
+
+    losses = self.loss(cls_scores, bbox_preds, labels, label_weights,
+                       bbox_targets, bbox_weights)
+    assert losses.get('loss_cls', 0) > 0, 'cls-loss should be non-zero'
+    assert losses.get('loss_bbox', 0) == 0, 'empty gt loss should be zero'
+
+    # Test bbox loss when truth is non-empty
+    gt_bboxes = [
+        torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
+    ]
+    gt_labels = [torch.LongTensor([2])]
+
+    sampling_results = _dummy_bbox_sampling(proposal_list, gt_bboxes,
+                                            gt_labels)
+
+    bbox_targets = self.get_target(sampling_results, gt_bboxes, gt_labels,
+                                   target_cfg)
+    labels, label_weights, bbox_targets, bbox_weights = bbox_targets
+
+    # Create dummy features "extracted" for each sampled bbox
+    num_sampled = sum(len(res.bboxes) for res in sampling_results)
+    dummy_feats = torch.rand(num_sampled, 8 * 3 * 3)
+    cls_scores, bbox_preds = self.forward(dummy_feats)
+
+    losses = self.loss(cls_scores, bbox_preds, labels, label_weights,
+                       bbox_targets, bbox_weights)
+    assert losses.get('loss_cls', 0) > 0, 'cls-loss should be non-zero'
+    assert losses.get('loss_bbox', 0) > 0, 'box-loss should be non-zero'
diff --git a/tests/test_sampler.py b/tests/test_sampler.py
new file mode 100644
index 0000000..c375d6e
--- /dev/null
+++ b/tests/test_sampler.py
@@ -0,0 +1,235 @@
+import torch
+
+from mmdet.core import MaxIoUAssigner
+from mmdet.core.bbox.samplers import OHEMSampler, RandomSampler
+
+
+def test_random_sampler():
+    assigner = MaxIoUAssigner(
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+        ignore_iof_thr=0.5,
+        ignore_wrt_candidates=False,
+    )
+    bboxes = torch.FloatTensor([
+        [0, 0, 10, 10],
+        [10, 10, 20, 20],
+        [5, 5, 15, 15],
+        [32, 32, 38, 42],
+    ])
+    gt_bboxes = torch.FloatTensor([
+        [0, 0, 10, 9],
+        [0, 10, 10, 19],
+    ])
+    gt_labels = torch.LongTensor([1, 2])
+    gt_bboxes_ignore = torch.Tensor([
+        [30, 30, 40, 40],
+    ])
+    assign_result = assigner.assign(
+        bboxes,
+        gt_bboxes,
+        gt_bboxes_ignore=gt_bboxes_ignore,
+        gt_labels=gt_labels)
+
+    sampler = RandomSampler(
+        num=10, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=True)
+
+    sample_result = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
+
+    assert len(sample_result.pos_bboxes) == len(sample_result.pos_inds)
+    assert len(sample_result.neg_bboxes) == len(sample_result.neg_inds)
+
+
+def test_random_sampler_empty_gt():
+    assigner = MaxIoUAssigner(
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+        ignore_iof_thr=0.5,
+        ignore_wrt_candidates=False,
+    )
+    bboxes = torch.FloatTensor([
+        [0, 0, 10, 10],
+        [10, 10, 20, 20],
+        [5, 5, 15, 15],
+        [32, 32, 38, 42],
+    ])
+    gt_bboxes = torch.empty(0, 4)
+    gt_labels = torch.empty(0, ).long()
+    assign_result = assigner.assign(bboxes, gt_bboxes, gt_labels=gt_labels)
+
+    sampler = RandomSampler(
+        num=10, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=True)
+
+    sample_result = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
+
+    assert len(sample_result.pos_bboxes) == len(sample_result.pos_inds)
+    assert len(sample_result.neg_bboxes) == len(sample_result.neg_inds)
+
+
+def test_random_sampler_empty_pred():
+    assigner = MaxIoUAssigner(
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+        ignore_iof_thr=0.5,
+        ignore_wrt_candidates=False,
+    )
+    bboxes = torch.empty(0, 4)
+    gt_bboxes = torch.FloatTensor([
+        [0, 0, 10, 9],
+        [0, 10, 10, 19],
+    ])
+    gt_labels = torch.LongTensor([1, 2])
+    assign_result = assigner.assign(bboxes, gt_bboxes, gt_labels=gt_labels)
+
+    sampler = RandomSampler(
+        num=10, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=True)
+
+    sample_result = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
+
+    assert len(sample_result.pos_bboxes) == len(sample_result.pos_inds)
+    assert len(sample_result.neg_bboxes) == len(sample_result.neg_inds)
+
+
+def _context_for_ohem():
+    try:
+        from test_forward import _get_detector_cfg
+    except ImportError:
+        # Hack: grab testing utils from test_forward to make a context for ohem
+        import sys
+        from os.path import dirname
+        sys.path.insert(0, dirname(__file__))
+        from test_forward import _get_detector_cfg
+    model, train_cfg, test_cfg = _get_detector_cfg(
+        'faster_rcnn_ohem_r50_fpn_1x.py')
+    model['pretrained'] = None
+    # torchvision roi align supports CPU
+    model['bbox_roi_extractor']['roi_layer']['use_torchvision'] = True
+    from mmdet.models import build_detector
+    context = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)
+    return context
+
+
+def test_ohem_sampler():
+
+    assigner = MaxIoUAssigner(
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+        ignore_iof_thr=0.5,
+        ignore_wrt_candidates=False,
+    )
+    bboxes = torch.FloatTensor([
+        [0, 0, 10, 10],
+        [10, 10, 20, 20],
+        [5, 5, 15, 15],
+        [32, 32, 38, 42],
+    ])
+    gt_bboxes = torch.FloatTensor([
+        [0, 0, 10, 9],
+        [0, 10, 10, 19],
+    ])
+    gt_labels = torch.LongTensor([1, 2])
+    gt_bboxes_ignore = torch.Tensor([
+        [30, 30, 40, 40],
+    ])
+    assign_result = assigner.assign(
+        bboxes,
+        gt_bboxes,
+        gt_bboxes_ignore=gt_bboxes_ignore,
+        gt_labels=gt_labels)
+
+    context = _context_for_ohem()
+
+    sampler = OHEMSampler(
+        num=10,
+        pos_fraction=0.5,
+        context=context,
+        neg_pos_ub=-1,
+        add_gt_as_proposals=True)
+
+    feats = [torch.rand(1, 256, int(2**i), int(2**i)) for i in [6, 5, 4, 3, 2]]
+    sample_result = sampler.sample(
+        assign_result, bboxes, gt_bboxes, gt_labels, feats=feats)
+
+    assert len(sample_result.pos_bboxes) == len(sample_result.pos_inds)
+    assert len(sample_result.neg_bboxes) == len(sample_result.neg_inds)
+
+
+def test_ohem_sampler_empty_gt():
+
+    assigner = MaxIoUAssigner(
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+        ignore_iof_thr=0.5,
+        ignore_wrt_candidates=False,
+    )
+    bboxes = torch.FloatTensor([
+        [0, 0, 10, 10],
+        [10, 10, 20, 20],
+        [5, 5, 15, 15],
+        [32, 32, 38, 42],
+    ])
+    gt_bboxes = torch.empty(0, 4)
+    gt_labels = torch.LongTensor([])
+    gt_bboxes_ignore = torch.Tensor([])
+    assign_result = assigner.assign(
+        bboxes,
+        gt_bboxes,
+        gt_bboxes_ignore=gt_bboxes_ignore,
+        gt_labels=gt_labels)
+
+    context = _context_for_ohem()
+
+    sampler = OHEMSampler(
+        num=10,
+        pos_fraction=0.5,
+        context=context,
+        neg_pos_ub=-1,
+        add_gt_as_proposals=True)
+
+    feats = [torch.rand(1, 256, int(2**i), int(2**i)) for i in [6, 5, 4, 3, 2]]
+
+    sample_result = sampler.sample(
+        assign_result, bboxes, gt_bboxes, gt_labels, feats=feats)
+
+    assert len(sample_result.pos_bboxes) == len(sample_result.pos_inds)
+    assert len(sample_result.neg_bboxes) == len(sample_result.neg_inds)
+
+
+def test_ohem_sampler_empty_pred():
+    assigner = MaxIoUAssigner(
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.5,
+        ignore_iof_thr=0.5,
+        ignore_wrt_candidates=False,
+    )
+    bboxes = torch.empty(0, 4)
+    gt_bboxes = torch.FloatTensor([
+        [0, 0, 10, 10],
+        [10, 10, 20, 20],
+        [5, 5, 15, 15],
+        [32, 32, 38, 42],
+    ])
+    gt_labels = torch.LongTensor([1, 2, 2, 3])
+    gt_bboxes_ignore = torch.Tensor([])
+    assign_result = assigner.assign(
+        bboxes,
+        gt_bboxes,
+        gt_bboxes_ignore=gt_bboxes_ignore,
+        gt_labels=gt_labels)
+
+    context = _context_for_ohem()
+
+    sampler = OHEMSampler(
+        num=10,
+        pos_fraction=0.5,
+        context=context,
+        neg_pos_ub=-1,
+        add_gt_as_proposals=True)
+
+    feats = [torch.rand(1, 256, int(2**i), int(2**i)) for i in [6, 5, 4, 3, 2]]
+
+    sample_result = sampler.sample(
+        assign_result, bboxes, gt_bboxes, gt_labels, feats=feats)
+
+    assert len(sample_result.pos_bboxes) == len(sample_result.pos_inds)
+    assert len(sample_result.neg_bboxes) == len(sample_result.neg_inds)
-- 
GitLab