diff --git a/mmdet/core/bbox/assigners/assign_result.py b/mmdet/core/bbox/assigners/assign_result.py
index 38a24d7e60ceadedb7fa5f4daa2016edf3fbc3bc..5e81c89782060549a403bf7f0959d44f35071b0a 100644
--- a/mmdet/core/bbox/assigners/assign_result.py
+++ b/mmdet/core/bbox/assigners/assign_result.py
@@ -1,7 +1,9 @@
 import torch
 
+from mmdet.utils import util_mixins
 
-class AssignResult(object):
+
+class AssignResult(util_mixins.NiceRepr):
     """
     Stores assignments between predicted and truth boxes.
 
@@ -44,20 +46,25 @@ class AssignResult(object):
         self.max_overlaps = max_overlaps
         self.labels = labels
 
-    def add_gt_(self, gt_labels):
-        self_inds = torch.arange(
-            1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
-        self.gt_inds = torch.cat([self_inds, self.gt_inds])
-
-        # Was this a bug?
-        # self.max_overlaps = torch.cat(
-        #     [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps])
-        # IIUC, It seems like the correct code should be:
-        self.max_overlaps = torch.cat(
-            [self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
+    @property
+    def num_preds(self):
+        """
+        Return the number of predictions in this assignment
+        """
+        return len(self.gt_inds)
 
-        if self.labels is not None:
-            self.labels = torch.cat([gt_labels, self.labels])
+    @property
+    def info(self):
+        """
+        Returns a dictionary of info about the object
+        """
+        return {
+            'num_gts': self.num_gts,
+            'num_preds': self.num_preds,
+            'gt_inds': self.gt_inds,
+            'max_overlaps': self.max_overlaps,
+            'labels': self.labels,
+        }
 
     def __nice__(self):
         """
@@ -81,12 +88,105 @@ class AssignResult(object):
             parts.append('labels.shape={!r}'.format(tuple(self.labels.shape)))
         return ', '.join(parts)
 
-    def __repr__(self):
-        nice = self.__nice__()
-        classname = self.__class__.__name__
-        return '<{}({}) at {}>'.format(classname, nice, hex(id(self)))
+    @classmethod
+    def random(cls, **kwargs):
+        """
+        Create random AssignResult for tests or debugging.
+
+        Kwargs:
+            num_preds: number of predicted boxes
+            num_gts: number of true boxes
+            p_ignore (float): probability of a predicted box assinged to an
+                ignored truth
+            p_assigned (float): probability of a predicted box not being
+                assigned
+            p_use_label (float | bool): with labels or not
+            rng (None | int | numpy.random.RandomState): seed or state
+
+        Returns:
+            AssignResult :
+
+        Example:
+            >>> from mmdet.core.bbox.assigners.assign_result import *  # NOQA
+            >>> self = AssignResult.random()
+            >>> print(self.info)
+        """
+        from mmdet.core.bbox import demodata
+        rng = demodata.ensure_rng(kwargs.get('rng', None))
+
+        num_gts = kwargs.get('num_gts', None)
+        num_preds = kwargs.get('num_preds', None)
+        p_ignore = kwargs.get('p_ignore', 0.3)
+        p_assigned = kwargs.get('p_assigned', 0.7)
+        p_use_label = kwargs.get('p_use_label', 0.5)
+        num_classes = kwargs.get('p_use_label', 3)
+
+        if num_gts is None:
+            num_gts = rng.randint(0, 8)
+        if num_preds is None:
+            num_preds = rng.randint(0, 16)
 
-    def __str__(self):
-        classname = self.__class__.__name__
-        nice = self.__nice__()
-        return '<{}({})>'.format(classname, nice)
+        if num_gts == 0:
+            max_overlaps = torch.zeros(num_preds, dtype=torch.float32)
+            gt_inds = torch.zeros(num_preds, dtype=torch.int64)
+            if p_use_label is True or p_use_label < rng.rand():
+                labels = torch.zeros(num_preds, dtype=torch.int64)
+            else:
+                labels = None
+        else:
+            import numpy as np
+            # Create an overlap for each predicted box
+            max_overlaps = torch.from_numpy(rng.rand(num_preds))
+
+            # Construct gt_inds for each predicted box
+            is_assigned = torch.from_numpy(rng.rand(num_preds) < p_assigned)
+            # maximum number of assignments constraints
+            n_assigned = min(num_preds, min(num_gts, is_assigned.sum()))
+
+            assigned_idxs = np.where(is_assigned)[0]
+            rng.shuffle(assigned_idxs)
+            assigned_idxs = assigned_idxs[0:n_assigned]
+            assigned_idxs.sort()
+
+            is_assigned[:] = 0
+            is_assigned[assigned_idxs] = True
+
+            is_ignore = torch.from_numpy(
+                rng.rand(num_preds) < p_ignore) & is_assigned
+
+            gt_inds = torch.zeros(num_preds, dtype=torch.int64)
+
+            true_idxs = np.arange(num_gts)
+            rng.shuffle(true_idxs)
+            true_idxs = torch.from_numpy(true_idxs)
+            gt_inds[is_assigned] = true_idxs[:n_assigned]
+
+            gt_inds = torch.from_numpy(
+                rng.randint(1, num_gts + 1, size=num_preds))
+            gt_inds[is_ignore] = -1
+            gt_inds[~is_assigned] = 0
+            max_overlaps[~is_assigned] = 0
+
+            if p_use_label is True or p_use_label < rng.rand():
+                if num_classes == 0:
+                    labels = torch.zeros(num_preds, dtype=torch.int64)
+                else:
+                    labels = torch.from_numpy(
+                        rng.randint(1, num_classes + 1, size=num_preds))
+                    labels[~is_assigned] = 0
+            else:
+                labels = None
+
+        self = cls(num_gts, gt_inds, max_overlaps, labels)
+        return self
+
+    def add_gt_(self, gt_labels):
+        self_inds = torch.arange(
+            1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
+        self.gt_inds = torch.cat([self_inds, self.gt_inds])
+
+        self.max_overlaps = torch.cat(
+            [self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
+
+        if self.labels is not None:
+            self.labels = torch.cat([gt_labels, self.labels])
diff --git a/mmdet/core/bbox/samplers/base_sampler.py b/mmdet/core/bbox/samplers/base_sampler.py
index a396a8d8a92d0bbbef535bf2fb1fd4407b691147..f437195f6b7c56eb93433682f3ccf61f71f49ed0 100644
--- a/mmdet/core/bbox/samplers/base_sampler.py
+++ b/mmdet/core/bbox/samplers/base_sampler.py
@@ -47,11 +47,30 @@ class BaseSampler(metaclass=ABCMeta):
 
         Returns:
             :obj:`SamplingResult`: Sampling result.
+
+        Example:
+            >>> from mmdet.core.bbox import RandomSampler
+            >>> from mmdet.core.bbox import AssignResult
+            >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes
+            >>> rng = ensure_rng(None)
+            >>> assign_result = AssignResult.random(rng=rng)
+            >>> bboxes = random_boxes(assign_result.num_preds, rng=rng)
+            >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
+            >>> gt_labels = None
+            >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
+            >>>                      add_gt_as_proposals=False)
+            >>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels)
         """
+        if len(bboxes.shape) < 2:
+            bboxes = bboxes[None, :]
+
         bboxes = bboxes[:, :4]
 
         gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
         if self.add_gt_as_proposals and len(gt_bboxes) > 0:
+            if gt_labels is None:
+                raise ValueError(
+                    'gt_labels must be given when add_gt_as_proposals is True')
             bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
             assign_result.add_gt_(gt_labels)
             gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
@@ -74,5 +93,6 @@ class BaseSampler(metaclass=ABCMeta):
             assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
         neg_inds = neg_inds.unique()
 
-        return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
-                              assign_result, gt_flags)
+        sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
+                                         assign_result, gt_flags)
+        return sampling_result
diff --git a/mmdet/core/bbox/samplers/random_sampler.py b/mmdet/core/bbox/samplers/random_sampler.py
index 0d02b2747fd73309c0d0840906727b350f1ea05e..3db00bab0ebb995c62ca0a89796017e1d2bf1db7 100644
--- a/mmdet/core/bbox/samplers/random_sampler.py
+++ b/mmdet/core/bbox/samplers/random_sampler.py
@@ -12,11 +12,12 @@ class RandomSampler(BaseSampler):
                  neg_pos_ub=-1,
                  add_gt_as_proposals=True,
                  **kwargs):
+        from mmdet.core.bbox import demodata
         super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub,
                                             add_gt_as_proposals)
+        self.rng = demodata.ensure_rng(kwargs.get('rng', None))
 
-    @staticmethod
-    def random_choice(gallery, num):
+    def random_choice(self, gallery, num):
         """Random select some elements from the gallery.
 
         It seems that Pytorch's implementation is slower than numpy so we use
@@ -26,7 +27,7 @@ class RandomSampler(BaseSampler):
         if isinstance(gallery, list):
             gallery = np.array(gallery)
         cands = np.arange(len(gallery))
-        np.random.shuffle(cands)
+        self.rng.shuffle(cands)
         rand_inds = cands[:num]
         if not isinstance(gallery, np.ndarray):
             rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
diff --git a/mmdet/core/bbox/samplers/sampling_result.py b/mmdet/core/bbox/samplers/sampling_result.py
index 696e65097109bb3e109860225cacc9e770db0b1f..dcf25eecd6727ae00fc4c1d2a0578f472dc90e0c 100644
--- a/mmdet/core/bbox/samplers/sampling_result.py
+++ b/mmdet/core/bbox/samplers/sampling_result.py
@@ -1,7 +1,25 @@
 import torch
 
+from mmdet.utils import util_mixins
 
-class SamplingResult(object):
+
+class SamplingResult(util_mixins.NiceRepr):
+    """
+    Example:
+        >>> # xdoctest: +IGNORE_WANT
+        >>> from mmdet.core.bbox.samplers.sampling_result import *  # NOQA
+        >>> self = SamplingResult.random(rng=10)
+        >>> print('self = {}'.format(self))
+        self = <SamplingResult({
+            'neg_bboxes': torch.Size([12, 4]),
+            'neg_inds': tensor([ 0,  1,  2,  4,  5,  6,  7,  8,  9, 10, 11, 12]),
+            'num_gts': 4,
+            'pos_assigned_gt_inds': tensor([], dtype=torch.int64),
+            'pos_bboxes': torch.Size([0, 4]),
+            'pos_inds': tensor([], dtype=torch.int64),
+            'pos_is_gt': tensor([], dtype=torch.uint8)
+        })>
+    """
 
     def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
                  gt_flags):
@@ -13,7 +31,17 @@ class SamplingResult(object):
 
         self.num_gts = gt_bboxes.shape[0]
         self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
-        self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]
+
+        if gt_bboxes.numel() == 0:
+            # hack for index error case
+            assert self.pos_assigned_gt_inds.numel() == 0
+            self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
+        else:
+            if len(gt_bboxes.shape) < 2:
+                gt_bboxes = gt_bboxes.view(-1, 4)
+
+            self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]
+
         if assign_result.labels is not None:
             self.pos_gt_labels = assign_result.labels[pos_inds]
         else:
@@ -22,3 +50,105 @@ class SamplingResult(object):
     @property
     def bboxes(self):
         return torch.cat([self.pos_bboxes, self.neg_bboxes])
+
+    def to(self, device):
+        """
+        Change the device of the data inplace.
+
+        Example:
+            >>> self = SamplingResult.random()
+            >>> print('self = {}'.format(self.to(None)))
+            >>> # xdoctest: +REQUIRES(--gpu)
+            >>> print('self = {}'.format(self.to(0)))
+        """
+        _dict = self.__dict__
+        for key, value in _dict.items():
+            if isinstance(value, torch.Tensor):
+                _dict[key] = value.to(device)
+        return self
+
+    def __nice__(self):
+        data = self.info.copy()
+        data['pos_bboxes'] = data.pop('pos_bboxes').shape
+        data['neg_bboxes'] = data.pop('neg_bboxes').shape
+        parts = ['\'{}\': {!r}'.format(k, v) for k, v in sorted(data.items())]
+        body = '    ' + ',\n    '.join(parts)
+        return '{\n' + body + '\n}'
+
+    @property
+    def info(self):
+        """
+        Returns a dictionary of info about the object
+        """
+        return {
+            'pos_inds': self.pos_inds,
+            'neg_inds': self.neg_inds,
+            'pos_bboxes': self.pos_bboxes,
+            'neg_bboxes': self.neg_bboxes,
+            'pos_is_gt': self.pos_is_gt,
+            'num_gts': self.num_gts,
+            'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
+        }
+
+    @classmethod
+    def random(cls, rng=None, **kwargs):
+        """
+        Args:
+            rng (None | int | numpy.random.RandomState): seed or state
+
+        Kwargs:
+            num_preds: number of predicted boxes
+            num_gts: number of true boxes
+            p_ignore (float): probability of a predicted box assinged to an
+                ignored truth
+            p_assigned (float): probability of a predicted box not being
+                assigned
+            p_use_label (float | bool): with labels or not
+
+        Returns:
+            AssignResult :
+
+        Example:
+            >>> from mmdet.core.bbox.samplers.sampling_result import *  # NOQA
+            >>> self = SamplingResult.random()
+            >>> print(self.__dict__)
+        """
+        from mmdet.core.bbox.samplers.random_sampler import RandomSampler
+        from mmdet.core.bbox.assigners.assign_result import AssignResult
+        from mmdet.core.bbox import demodata
+        rng = demodata.ensure_rng(rng)
+
+        # make probabalistic?
+        num = 32
+        pos_fraction = 0.5
+        neg_pos_ub = -1
+
+        assign_result = AssignResult.random(rng=rng, **kwargs)
+
+        # Note we could just compute an assignment
+        bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
+        gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)
+
+        if rng.rand() > 0.2:
+            # sometimes algorithms squeeze their data, be robust to that
+            gt_bboxes = gt_bboxes.squeeze()
+            bboxes = bboxes.squeeze()
+
+        if assign_result.labels is None:
+            gt_labels = None
+        else:
+            gt_labels = None  # todo
+
+        if gt_labels is None:
+            add_gt_as_proposals = False
+        else:
+            add_gt_as_proposals = True  # make probabalistic?
+
+        sampler = RandomSampler(
+            num,
+            pos_fraction,
+            neg_pos_ubo=neg_pos_ub,
+            add_gt_as_proposals=add_gt_as_proposals,
+            rng=rng)
+        self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
+        return self
diff --git a/mmdet/utils/util_mixins.py b/mmdet/utils/util_mixins.py
new file mode 100644
index 0000000000000000000000000000000000000000..5585ac652736e36ac5d76a05e2aabb2aaea5786e
--- /dev/null
+++ b/mmdet/utils/util_mixins.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+"""
+This module defines the :class:`NiceRepr` mixin class, which defines a
+``__repr__`` and ``__str__`` method that only depend on a custom ``__nice__``
+method, which you must define. This means you only have to overload one
+function instead of two.  Furthermore, if the object defines a ``__len__``
+method, then the ``__nice__`` method defaults to something sensible, otherwise
+it is treated as abstract and raises ``NotImplementedError``.
+
+To use simply have your object inherit from :class:`NiceRepr`
+(multi-inheritance should be ok).
+
+This code was copied from the ubelt library: https://github.com/Erotemic/ubelt
+
+Example:
+    >>> # Objects that define __nice__ have a default __str__ and __repr__
+    >>> class Student(NiceRepr):
+    ...    def __init__(self, name):
+    ...        self.name = name
+    ...    def __nice__(self):
+    ...        return self.name
+    >>> s1 = Student('Alice')
+    >>> s2 = Student('Bob')
+    >>> print('s1 = {}'.format(s1))
+    >>> print('s2 = {}'.format(s2))
+    s1 = <Student(Alice)>
+    s2 = <Student(Bob)>
+
+Example:
+    >>> # Objects that define __len__ have a default __nice__
+    >>> class Group(NiceRepr):
+    ...    def __init__(self, data):
+    ...        self.data = data
+    ...    def __len__(self):
+    ...        return len(self.data)
+    >>> g = Group([1, 2, 3])
+    >>> print('g = {}'.format(g))
+    g = <Group(3)>
+
+"""
+import warnings
+
+
+class NiceRepr(object):
+    """
+    Inherit from this class and define ``__nice__`` to "nicely" print your
+    objects.
+
+    Defines ``__str__`` and ``__repr__`` in terms of ``__nice__`` function
+    Classes that inherit from :class:`NiceRepr` should redefine ``__nice__``.
+    If the inheriting class has a ``__len__``, method then the default
+    ``__nice__`` method will return its length.
+
+    Example:
+        >>> class Foo(NiceRepr):
+        ...    def __nice__(self):
+        ...        return 'info'
+        >>> foo = Foo()
+        >>> assert str(foo) == '<Foo(info)>'
+        >>> assert repr(foo).startswith('<Foo(info) at ')
+
+    Example:
+        >>> class Bar(NiceRepr):
+        ...    pass
+        >>> bar = Bar()
+        >>> import pytest
+        >>> with pytest.warns(None) as record:
+        >>>     assert 'object at' in str(bar)
+        >>>     assert 'object at' in repr(bar)
+
+    Example:
+        >>> class Baz(NiceRepr):
+        ...    def __len__(self):
+        ...        return 5
+        >>> baz = Baz()
+        >>> assert str(baz) == '<Baz(5)>'
+    """
+
+    def __nice__(self):
+        if hasattr(self, '__len__'):
+            # It is a common pattern for objects to use __len__ in __nice__
+            # As a convenience we define a default __nice__ for these objects
+            return str(len(self))
+        else:
+            # In all other cases force the subclass to overload __nice__
+            raise NotImplementedError(
+                'Define the __nice__ method for {!r}'.format(self.__class__))
+
+    def __repr__(self):
+        try:
+            nice = self.__nice__()
+            classname = self.__class__.__name__
+            return '<{0}({1}) at {2}>'.format(classname, nice, hex(id(self)))
+        except NotImplementedError as ex:
+            warnings.warn(str(ex), category=RuntimeWarning)
+            return object.__repr__(self)
+
+    def __str__(self):
+        try:
+            classname = self.__class__.__name__
+            nice = self.__nice__()
+            return '<{0}({1})>'.format(classname, nice)
+        except NotImplementedError as ex:
+            warnings.warn(str(ex), category=RuntimeWarning)
+            return object.__repr__(self)
diff --git a/tests/test_assigner.py b/tests/test_assigner.py
index 50cf7d530eece7a334b521fc52f647e10d0f74ee..5348eaba3a38047a74497539c60ddcad993be72a 100644
--- a/tests/test_assigner.py
+++ b/tests/test_assigner.py
@@ -259,3 +259,19 @@ def test_approx_iou_assigner_with_empty_boxes_and_gt():
     assign_result = self.assign(approxs, squares, approxs_per_octave,
                                 gt_bboxes)
     assert len(assign_result.gt_inds) == 0
+
+
+def test_random_assign_result():
+    """
+    Test random instantiation of assign result to catch corner cases
+    """
+    from mmdet.core.bbox.assigners.assign_result import AssignResult
+    AssignResult.random()
+
+    AssignResult.random(num_gts=0, num_preds=0)
+    AssignResult.random(num_gts=0, num_preds=3)
+    AssignResult.random(num_gts=3, num_preds=3)
+    AssignResult.random(num_gts=0, num_preds=3)
+    AssignResult.random(num_gts=7, num_preds=7)
+    AssignResult.random(num_gts=7, num_preds=64)
+    AssignResult.random(num_gts=24, num_preds=3)
diff --git a/tests/test_sampler.py b/tests/test_sampler.py
index c375d6e6f9a25fadf9f9de4f7c452a7dbeb90737..c75360268e6fd3afd6cbfb4ecf09a4de5cbf138a 100644
--- a/tests/test_sampler.py
+++ b/tests/test_sampler.py
@@ -233,3 +233,17 @@ def test_ohem_sampler_empty_pred():
 
     assert len(sample_result.pos_bboxes) == len(sample_result.pos_inds)
     assert len(sample_result.neg_bboxes) == len(sample_result.neg_inds)
+
+
+def test_random_sample_result():
+    from mmdet.core.bbox.samplers.sampling_result import SamplingResult
+    SamplingResult.random(num_gts=0, num_preds=0)
+    SamplingResult.random(num_gts=0, num_preds=3)
+    SamplingResult.random(num_gts=3, num_preds=3)
+    SamplingResult.random(num_gts=0, num_preds=3)
+    SamplingResult.random(num_gts=7, num_preds=7)
+    SamplingResult.random(num_gts=7, num_preds=64)
+    SamplingResult.random(num_gts=24, num_preds=3)
+
+    for i in range(3):
+        SamplingResult.random(rng=i)