diff --git a/mmdet/core/anchor/anchor_generator.py b/mmdet/core/anchor/anchor_generator.py
index 61304307415392b17f78dd82f53c731dcca1a144..ec27c4c02c919af8103440966a69c723291ea427 100644
--- a/mmdet/core/anchor/anchor_generator.py
+++ b/mmdet/core/anchor/anchor_generator.py
@@ -2,6 +2,17 @@ import torch
 
 
 class AnchorGenerator(object):
+    """
+    Examples:
+        >>> from mmdet.core import AnchorGenerator
+        >>> self = AnchorGenerator(9, [1.], [1.])
+        >>> all_anchors = self.grid_anchors((2, 2), device='cpu')
+        >>> print(all_anchors)
+        tensor([[ 0.,  0.,  8.,  8.],
+                [16.,  0., 24.,  8.],
+                [ 0., 16.,  8., 24.],
+                [16., 16., 24., 24.]])
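+        >>> # A quick sanity check (a sketch; grid_anchors defaults to
+        >>> # stride=16, so a stride of 32 doubles the grid spacing):
+        >>> strided = self.grid_anchors((2, 2), stride=32, device='cpu')
+        >>> assert (strided[1] - strided[0])[0] == 32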
+    """
 
     def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
         self.base_size = base_size
diff --git a/mmdet/core/bbox/transforms.py b/mmdet/core/bbox/transforms.py
index 580b9bdfb24d43e80600ad44a70cde6b8ccc58e3..b9d1e660546abfde71b1c8010ef2b5cb4c380957 100644
--- a/mmdet/core/bbox/transforms.py
+++ b/mmdet/core/bbox/transforms.py
@@ -37,6 +37,44 @@ def delta2bbox(rois,
                stds=[1, 1, 1, 1],
                max_shape=None,
                wh_ratio_clip=16 / 1000):
+    """
+    Apply deltas to shift/scale base boxes.
+
+    Typically the rois are anchor or proposed bounding boxes and the deltas are
+    network outputs used to shift/scale those boxes.
+
+    Args:
+        rois (Tensor): boxes to be transformed. Has shape (N, 4)
+        deltas (Tensor): encoded offsets with respect to each roi.
+            Has shape (N, 4). Note N = num_anchors * W * H when rois is a grid
+            of anchors. Offset encoding follows [1]_.
+        means (list): denormalizing means for delta coordinates
+        stds (list): denormalizing standard deviations for delta coordinates
+        max_shape (tuple[int, int]): maximum bounds for boxes, specifies (H, W)
+        wh_ratio_clip (float): maximum aspect ratio for boxes.
+
+    Returns:
+        Tensor: boxes with shape (N, 4), where columns represent
+            tl_x, tl_y, br_x, br_y.
+
+    References:
+        .. [1] https://arxiv.org/abs/1311.2524
+
+    Example:
+        >>> import torch
+        >>> rois = torch.Tensor([[ 0.,  0.,  1.,  1.],
+        >>>                      [ 0.,  0.,  1.,  1.],
+        >>>                      [ 0.,  0.,  1.,  1.],
+        >>>                      [ 5.,  5.,  5.,  5.]])
+        >>> deltas = torch.Tensor([[  0.,   0.,   0.,   0.],
+        >>>                        [  1.,   1.,   1.,   1.],
+        >>>                        [  0.,   0.,   2.,  -1.],
+        >>>                        [ 0.7, -1.9, -0.5,  0.3]])
+        >>> delta2bbox(rois, deltas, max_shape=(32, 32))
+        tensor([[0.0000, 0.0000, 1.0000, 1.0000],
+                [0.2817, 0.2817, 4.7183, 4.7183],
+                [0.0000, 0.6321, 7.3891, 0.3679],
+                [5.8967, 2.9251, 5.5033, 3.2749]])
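+        >>> # Sketch of the inverse relationship, assuming the companion
+        >>> # encoder bbox2delta in this module with default means/stds:
+        >>> gt = torch.Tensor([[2., 2., 9., 9.]])
+        >>> enc = bbox2delta(rois[:1], gt)
+        >>> assert torch.allclose(delta2bbox(rois[:1], enc), gt, atol=1e-4)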
+    """
     means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4)
     stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4)
     denorm_deltas = deltas * stds + means
@@ -47,14 +85,19 @@ def delta2bbox(rois,
     max_ratio = np.abs(np.log(wh_ratio_clip))
     dw = dw.clamp(min=-max_ratio, max=max_ratio)
     dh = dh.clamp(min=-max_ratio, max=max_ratio)
+    # Compute center of each roi
     px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx)
     py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy)
+    # Compute width/height of each roi
     pw = (rois[:, 2] - rois[:, 0] + 1.0).unsqueeze(1).expand_as(dw)
     ph = (rois[:, 3] - rois[:, 1] + 1.0).unsqueeze(1).expand_as(dh)
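+    # Note: the "+ 1.0" follows the legacy convention that box coordinates
+    # are inclusive pixel indices, i.e. a box spanning x1..x2 is
+    # x2 - x1 + 1 pixels wide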
+    # Use exp(network energy) to enlarge/shrink each roi
     gw = pw * dw.exp()
     gh = ph * dh.exp()
+    # Use network energy to shift the center of each roi
     gx = torch.addcmul(px, 1, pw, dx)  # gx = px + pw * dx
     gy = torch.addcmul(py, 1, ph, dy)  # gy = py + ph * dy
+    # Convert center-xy/width/height to top-left, bottom-right
     x1 = gx - gw * 0.5 + 0.5
     y1 = gy - gh * 0.5 + 0.5
     x2 = gx + gw * 0.5 - 0.5
diff --git a/mmdet/core/post_processing/bbox_nms.py b/mmdet/core/post_processing/bbox_nms.py
index 01beecd43abc7641d43773b192f988a55f4295d9..ce3794c6450347d9e0b469066584ccfcdef6f8d0 100644
--- a/mmdet/core/post_processing/bbox_nms.py
+++ b/mmdet/core/post_processing/bbox_nms.py
@@ -13,7 +13,8 @@ def multiclass_nms(multi_bboxes,
 
     Args:
         multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
-        multi_scores (Tensor): shape (n, #class)
+        multi_scores (Tensor): shape (n, #class), where the 0th column
+            contains scores of the background class, which will be ignored.
         score_thr (float): bbox threshold, bboxes with scores lower than it
             will not be considered.
         nms_thr (float): NMS IoU threshold
diff --git a/mmdet/models/anchor_heads/anchor_head.py b/mmdet/models/anchor_heads/anchor_head.py
index 55c25de9e11c027dd6be213cf4512b7acf1a67ee..d526b3b3beb37b3177c2bb25f3bfcaaa7fe887d4 100644
--- a/mmdet/models/anchor_heads/anchor_head.py
+++ b/mmdet/models/anchor_heads/anchor_head.py
@@ -16,6 +16,8 @@ class AnchorHead(nn.Module):
     """Anchor-based head (RPN, RetinaNet, SSD, etc.).
 
     Args:
+        num_classes (int): Number of categories including the background
+            category.
         in_channels (int): Number of channels in the input feature map.
         feat_channels (int): Number of channels of the feature map.
         anchor_scales (Iterable): Anchor scales.
@@ -45,6 +47,7 @@ class AnchorHead(nn.Module):
                  loss_bbox=dict(
                      type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)):
         super(AnchorHead, self).__init__()
+        # NOTE: in_channels is only used in child classes (e.g. RetinaHead)
         self.in_channels = in_channels
         self.num_classes = num_classes
         self.feat_channels = feat_channels
@@ -62,6 +65,10 @@ class AnchorHead(nn.Module):
             self.cls_out_channels = num_classes - 1
         else:
             self.cls_out_channels = num_classes
+
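+        # e.g. a sigmoid head with num_classes=1 would leave zero foreground
+        # channels; reject that early with a clear error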
+        if self.cls_out_channels <= 0:
+            raise ValueError('num_classes={} is too small'.format(num_classes))
+
         self.loss_cls = build_loss(loss_cls)
         self.loss_bbox = build_loss(loss_bbox)
         self.fp16_enabled = False
@@ -202,6 +209,46 @@ class AnchorHead(nn.Module):
     @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
     def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg,
                    rescale=False):
+        """
+        Transform network output for a batch into labeled boxes.
+
+        Args:
+            cls_scores (list[Tensor]): Box scores for each scale level.
+                Has shape (N, num_anchors * num_classes, H, W).
+            bbox_preds (list[Tensor]): Box energies / deltas for each scale
+                level with shape (N, num_anchors * 4, H, W).
+            img_metas (list[dict]): Size / scale info for each image.
+            cfg (mmcv.Config): Test / postprocessing configuration.
+            rescale (bool): If True, return boxes in original image space.
+
+        Returns:
+            list[tuple[Tensor, Tensor]]: Each item in result_list is a 2-tuple.
+                The first item is an (n, 5) tensor, where the first 4 columns
+                are bounding box positions (tl_x, tl_y, br_x, br_y) and the
+                5-th column is a score between 0 and 1. The second item is a
+                (n,) tensor where each item is the class index of the
+                corresponding box.
+
+        Example:
+            >>> import mmcv
+            >>> import torch
+            >>> self = AnchorHead(num_classes=9, in_channels=1,
+            >>>                   feat_channels=1)
+            >>> img_metas = [{'img_shape': (32, 32, 3), 'scale_factor': 1}]
+            >>> cfg = mmcv.Config(dict(
+            >>>     score_thr=0.00,
+            >>>     nms=dict(type='nms', iou_thr=1.0),
+            >>>     max_per_img=10))
+            >>> feat = torch.rand(1, 1, 3, 3)
+            >>> cls_score, bbox_pred = self.forward_single(feat)
+            >>> # note the input lists are over different levels, not images
+            >>> cls_scores, bbox_preds = [cls_score], [bbox_pred]
+            >>> result_list = self.get_bboxes(cls_scores, bbox_preds,
+            >>>                               img_metas, cfg)
+            >>> det_bboxes, det_labels = result_list[0]
+            >>> assert len(result_list) == 1
+            >>> assert det_bboxes.shape[1] == 5
+            >>> assert len(det_bboxes) == len(det_labels) == cfg.max_per_img
+        """
         assert len(cls_scores) == len(bbox_preds)
         num_levels = len(cls_scores)
 
@@ -229,18 +276,21 @@ class AnchorHead(nn.Module):
         return result_list
 
     def get_bboxes_single(self,
-                          cls_scores,
-                          bbox_preds,
+                          cls_score_list,
+                          bbox_pred_list,
                           mlvl_anchors,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False):
-        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
+        """
+        Transform outputs for a single batch item into labeled boxes.
+        """
+        assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
         mlvl_bboxes = []
         mlvl_scores = []
-        for cls_score, bbox_pred, anchors in zip(cls_scores, bbox_preds,
-                                                 mlvl_anchors):
+        for cls_score, bbox_pred, anchors in zip(cls_score_list,
+                                                 bbox_pred_list, mlvl_anchors):
             assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
             cls_score = cls_score.permute(1, 2,
                                           0).reshape(-1, self.cls_out_channels)
@@ -251,6 +301,7 @@ class AnchorHead(nn.Module):
             bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
             nms_pre = cfg.get('nms_pre', -1)
             if nms_pre > 0 and scores.shape[0] > nms_pre:
+                # Get maximum scores for foreground classes.
                 if self.use_sigmoid_cls:
                     max_scores, _ = scores.max(dim=1)
                 else:
@@ -268,6 +319,7 @@ class AnchorHead(nn.Module):
             mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
         mlvl_scores = torch.cat(mlvl_scores)
         if self.use_sigmoid_cls:
+            # Add a dummy background class to the front when using sigmoid,
+            # because multiclass_nms expects column 0 to hold (ignored)
+            # background scores
             padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
             mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
         det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,