diff --git a/mmdet/core/bbox_ops/__init__.py b/mmdet/core/bbox_ops/__init__.py
index 22163f75ef5484a48ed223e417f83537522532a7..a5c21dce52f25781e2e4e3e760a837d4d36eec5c 100644
--- a/mmdet/core/bbox_ops/__init__.py
+++ b/mmdet/core/bbox_ops/__init__.py
@@ -1,14 +1,15 @@
 from .geometry import bbox_overlaps
 from .sampling import (random_choice, bbox_assign, bbox_assign_wrt_overlaps,
-                       bbox_sampling, sample_positives, sample_negatives)
+                       bbox_sampling, bbox_sampling_pos, bbox_sampling_neg,
+                       sample_bboxes)
 from .transforms import (bbox2delta, delta2bbox, bbox_flip, bbox_mapping,
                          bbox_mapping_back, bbox2roi, roi2bbox, bbox2result)
 from .bbox_target import bbox_target
 
 __all__ = [
     'bbox_overlaps', 'random_choice', 'bbox_assign',
-    'bbox_assign_wrt_overlaps', 'bbox_sampling', 'sample_positives',
-    'sample_negatives', 'bbox2delta', 'delta2bbox', 'bbox_flip',
-    'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result',
-    'bbox_target'
+    'bbox_assign_wrt_overlaps', 'bbox_sampling', 'bbox_sampling_pos',
+    'bbox_sampling_neg', 'sample_bboxes', 'bbox2delta', 'delta2bbox',
+    'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox',
+    'bbox2result', 'bbox_target'
 ]
diff --git a/mmdet/core/bbox_ops/sampling.py b/mmdet/core/bbox_ops/sampling.py
index 28043182acf41583373a86729bf10f309f384e8a..80f8c8207cc55b37d647bef33f0486d0a49ccd4a 100644
--- a/mmdet/core/bbox_ops/sampling.py
+++ b/mmdet/core/bbox_ops/sampling.py
@@ -78,27 +78,32 @@ def bbox_assign_wrt_overlaps(overlaps,
                              pos_iou_thr=0.5,
                              neg_iou_thr=0.5,
                              min_pos_iou=.0):
-    """Assign a corresponding gt bbox or background to each proposal/anchor
-    This function assign a gt bbox to every proposal, each proposals will be
+    """Assign a corresponding gt bbox or background to each proposal/anchor.
+
+    This method assign a gt bbox to every proposal, each proposals will be
     assigned with -1, 0, or a positive number. -1 means don't care, 0 means
     negative sample, positive number is the index (1-based) of assigned gt.
     The assignment is done in following steps, the order matters:
+
     1. assign every anchor to -1
     2. assign proposals whose iou with all gts < neg_iou_thr to 0
     3. for each anchor, if the iou with its nearest gt >= pos_iou_thr,
     assign it to that bbox
     4. for each gt bbox, assign its nearest proposals(may be more than one)
     to itself
+
     Args:
-        overlaps(Tensor): overlaps between n proposals and k gt_bboxes, shape(n, k)
-        gt_labels(Tensor, optional): shape (k, )
-        pos_iou_thr(float): iou threshold for positive bboxes
-        neg_iou_thr(float or tuple): iou threshold for negative bboxes
-        min_pos_iou(float): minimum iou for a bbox to be considered as a positive bbox,
-                            for RPN, it is usually set as 0, for Fast R-CNN,
-                            it is usually set as pos_iou_thr
+        overlaps (Tensor): Overlaps between n proposals and k gt_bboxes,
+            shape(n, k).
+        gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).
+        pos_iou_thr (float): IoU threshold for positive bboxes.
+        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
+        min_pos_iou (float): Minimum IoU for a bbox to be considered as a
+            positive bbox. This argument only affects the 4th step.
+
     Returns:
-        tuple: (assigned_gt_inds, argmax_overlaps, max_overlaps), shape (n, )
+        tuple: (assigned_gt_inds, [assigned_labels], argmax_overlaps,
+            max_overlaps), shape (n, )
     """
     num_bboxes, num_gts = overlaps.size(0), overlaps.size(1)
     # 1. assign -1 by default
@@ -144,8 +149,9 @@ def bbox_assign_wrt_overlaps(overlaps,
         return assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps
 
 
-def sample_positives(assigned_gt_inds, num_expected, balance_sampling=True):
-    """Balance sampling for positive bboxes/anchors
+def bbox_sampling_pos(assigned_gt_inds, num_expected, balance_sampling=True):
+    """Balance sampling for positive bboxes/anchors.
+
     1. calculate average positive num for each gt: num_per_gt
     2. sample at most num_per_gt positives for each gt
     3. random sampling from rest anchors if not enough fg
@@ -186,15 +192,16 @@ def sample_positives(assigned_gt_inds, num_expected, balance_sampling=True):
         return sampled_inds
 
 
-def sample_negatives(assigned_gt_inds,
-                     num_expected,
-                     max_overlaps=None,
-                     balance_thr=0,
-                     hard_fraction=0.5):
-    """Balance sampling for negative bboxes/anchors
-    negative samples are split into 2 set: hard(balance_thr <= iou < neg_iou_thr)
-    and easy(iou < balance_thr), around equal number of bg are sampled
-    from each set.
+def bbox_sampling_neg(assigned_gt_inds,
+                      num_expected,
+                      max_overlaps=None,
+                      balance_thr=0,
+                      hard_fraction=0.5):
+    """Balance sampling for negative bboxes/anchors.
+
+    Negative samples are split into 2 set: hard (balance_thr <= iou <
+    neg_iou_thr) and easy(iou < balance_thr). The sampling ratio is controlled
+    by `hard_fraction`.
     """
     neg_inds = torch.nonzero(assigned_gt_inds == 0)
     if neg_inds.numel() != 0:
@@ -247,17 +254,87 @@ def bbox_sampling(assigned_gt_inds,
                   max_overlaps=None,
                   neg_balance_thr=0,
                   neg_hard_fraction=0.5):
+    """Sample positive and negative bboxes given assigned results.
+
+    Args:
+        assigned_gt_inds (Tensor): Assigned gt indices for each bbox.
+        num_expected (int): Expected total samples (pos and neg).
+        pos_fraction (float): Positive sample fraction.
+        neg_pos_ub (float): Negative/Positive upper bound.
+        pos_balance_sampling(bool): Whether to sample positive samples around
+            each gt bbox evenly.
+        max_overlaps (Tensor, optional): For each bbox, the max IoU of all gts.
+            Used for negative balance sampling only.
+        neg_balance_thr (float, optional): IoU threshold for simple/hard
+            negative balance sampling.
+        neg_hard_fraction (float, optional): Fraction of hard negative samples
+            for negative balance sampling.
+
+    Returns:
+        tuple[Tensor]: positive bbox indices, negative bbox indices.
+    """
     num_expected_pos = int(num_expected * pos_fraction)
-    pos_inds = sample_positives(assigned_gt_inds, num_expected_pos,
-                                pos_balance_sampling)
+    pos_inds = bbox_sampling_pos(assigned_gt_inds, num_expected_pos,
+                                 pos_balance_sampling)
+    # We found that sampled indices have duplicated items occasionally.
+    # (mab be a bug of PyTorch)
     pos_inds = pos_inds.unique()
     num_sampled_pos = pos_inds.numel()
     num_neg_max = int(
         neg_pos_ub *
         num_sampled_pos) if num_sampled_pos > 0 else int(neg_pos_ub)
     num_expected_neg = min(num_neg_max, num_expected - num_sampled_pos)
-    neg_inds = sample_negatives(assigned_gt_inds, num_expected_neg,
-                                max_overlaps, neg_balance_thr,
-                                neg_hard_fraction)
+    neg_inds = bbox_sampling_neg(assigned_gt_inds, num_expected_neg,
+                                 max_overlaps, neg_balance_thr,
+                                 neg_hard_fraction)
     neg_inds = neg_inds.unique()
     return pos_inds, neg_inds
+
+
+def sample_bboxes(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg):
+    """Sample positive and negative bboxes.
+
+    This is a simple implementation of bbox sampling given candidates and
+    ground truth bboxes, which includes 3 steps.
+
+    1. Assign gt to each bbox.
+    2. Add gt bboxes to the sampling pool (optional).
+    3. Perform positive and negative sampling.
+
+    Args:
+        bboxes (Tensor): Boxes to be sampled from.
+        gt_bboxes (Tensor): Ground truth bboxes.
+        gt_bboxes_ignore (Tensor): Ignored ground truth bboxes. In MS COCO,
+            `crowd` bboxes are considered as ignored.
+        gt_labels (Tensor): Class labels of ground truth bboxes.
+        cfg (dict): Sampling configs.
+
+    Returns:
+        tuple[Tensor]: pos_bboxes, neg_bboxes, pos_assigned_gt_inds,
+            pos_gt_bboxes, pos_gt_labels
+    """
+    bboxes = bboxes[:, :4]
+    assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps = \
+        bbox_assign(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels,
+                    cfg.pos_iou_thr, cfg.neg_iou_thr, cfg.min_pos_iou,
+                    cfg.crowd_thr)
+
+    if cfg.add_gt_as_proposals:
+        bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
+        gt_assign_self = torch.arange(
+            1, len(gt_labels) + 1, dtype=torch.long, device=bboxes.device)
+        assigned_gt_inds = torch.cat([gt_assign_self, assigned_gt_inds])
+        assigned_labels = torch.cat([gt_labels, assigned_labels])
+
+    pos_inds, neg_inds = bbox_sampling(
+        assigned_gt_inds, cfg.roi_batch_size, cfg.pos_fraction, cfg.neg_pos_ub,
+        cfg.pos_balance_sampling, max_overlaps, cfg.neg_balance_thr)
+
+    pos_bboxes = bboxes[pos_inds]
+    neg_bboxes = bboxes[neg_inds]
+    pos_assigned_gt_inds = assigned_gt_inds[pos_inds] - 1
+    pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
+    pos_gt_labels = assigned_labels[pos_inds]
+
+    return (pos_bboxes, neg_bboxes, pos_assigned_gt_inds, pos_gt_bboxes,
+            pos_gt_labels)
diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py
index f69db22ced983e2b41eb693daa5ec1099f6f4a55..8573d83215f120ba392a2f6b45cb9b6b93ca0519 100644
--- a/mmdet/models/detectors/two_stage.py
+++ b/mmdet/models/detectors/two_stage.py
@@ -4,7 +4,7 @@ import torch.nn as nn
 from .base import BaseDetector
 from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
 from .. import builder
-from mmdet.core import bbox2roi, bbox2result, multi_apply
+from mmdet.core import sample_bboxes, bbox2roi, bbox2result, multi_apply
 
 
 class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
@@ -97,13 +97,14 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
             proposal_list = proposals
 
         if self.with_bbox:
-            rcnn_train_cfg_list = [
-                self.train_cfg.rcnn for _ in range(len(proposal_list))
-            ]
             (pos_proposals, neg_proposals, pos_assigned_gt_inds, pos_gt_bboxes,
              pos_gt_labels) = multi_apply(
-                 self.bbox_roi_extractor.sample_proposals, proposal_list,
-                 gt_bboxes, gt_bboxes_ignore, gt_labels, rcnn_train_cfg_list)
+                 sample_bboxes,
+                 proposal_list,
+                 gt_bboxes,
+                 gt_bboxes_ignore,
+                 gt_labels,
+                 cfg=self.train_cfg.rcnn)
             (labels, label_weights, bbox_targets,
              bbox_weights) = self.bbox_head.get_bbox_target(
                  pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels,
diff --git a/mmdet/models/roi_extractors/__init__.py b/mmdet/models/roi_extractors/__init__.py
index e76e689753f10e87b3f6d9482e880b902f9b747e..9161708ce13fa4f0a6bb188e82a19a163b9b7e4f 100644
--- a/mmdet/models/roi_extractors/__init__.py
+++ b/mmdet/models/roi_extractors/__init__.py
@@ -1,3 +1,3 @@
-from .single_level import SingleLevelRoI
+from .single_level import SingleRoIExtractor
 
-__all__ = ['SingleLevelRoI']
+__all__ = ['SingleRoIExtractor']
diff --git a/mmdet/models/roi_extractors/single_level.py b/mmdet/models/roi_extractors/single_level.py
index 6aa29e598e58696634d7934ecc00bb7105084d62..3f97a631f987104422f65110a2cb6b49e080de0e 100644
--- a/mmdet/models/roi_extractors/single_level.py
+++ b/mmdet/models/roi_extractors/single_level.py
@@ -4,19 +4,27 @@ import torch
 import torch.nn as nn
 
 from mmdet import ops
-from mmdet.core import bbox_assign, bbox_sampling
 
 
-class SingleLevelRoI(nn.Module):
-    """Extract RoI features from a single level feature map. Each RoI is
-    mapped to a level according to its scale."""
+class SingleRoIExtractor(nn.Module):
+    """Extract RoI features from a single level feature map.
+
+    If there are mulitple input feature levels, each RoI is mapped to a level
+    according to its scale.
+
+    Args:
+        roi_layer (dict): Specify RoI layer type and arguments.
+        out_channels (int): Output channels of RoI layers.
+        featmap_strides (int): Strides of input feature maps.
+        finest_scale (int): Scale threshold of mapping to level 0.
+    """
 
     def __init__(self,
                  roi_layer,
                  out_channels,
                  featmap_strides,
                  finest_scale=56):
-        super(SingleLevelRoI, self).__init__()
+        super(SingleRoIExtractor, self).__init__()
         self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides)
         self.out_channels = out_channels
         self.featmap_strides = featmap_strides
@@ -24,6 +32,7 @@ class SingleLevelRoI(nn.Module):
 
     @property
     def num_inputs(self):
+        """int: Input feature map levels."""
         return len(self.featmap_strides)
 
     def init_weights(self):
@@ -39,12 +48,19 @@ class SingleLevelRoI(nn.Module):
         return roi_layers
 
     def map_roi_levels(self, rois, num_levels):
-        """Map rois to corresponding feature levels (0-based) by scales.
+        """Map rois to corresponding feature levels by scales.
 
         - scale < finest_scale: level 0
         - finest_scale <= scale < finest_scale * 2: level 1
         - finest_scale * 2 <= scale < finest_scale * 4: level 2
         - scale >= finest_scale * 4: level 3
+
+        Args:
+            rois (Tensor): Input RoIs, shape (k, 5).
+            num_levels (int): Total level number.
+
+        Returns:
+            Tensor: Level index (0-based) of each RoI, shape (k, )
         """
         scale = torch.sqrt(
             (rois[:, 3] - rois[:, 1] + 1) * (rois[:, 4] - rois[:, 2] + 1))
@@ -52,43 +68,7 @@ class SingleLevelRoI(nn.Module):
         target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
         return target_lvls
 
-    def sample_proposals(self, proposals, gt_bboxes, gt_bboxes_ignore,
-                         gt_labels, cfg):
-        proposals = proposals[:, :4]
-        assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps = \
-            bbox_assign(proposals, gt_bboxes, gt_bboxes_ignore, gt_labels,
-                        cfg.pos_iou_thr, cfg.neg_iou_thr, cfg.min_pos_iou,
-                        cfg.crowd_thr)
-
-        if cfg.add_gt_as_proposals:
-            proposals = torch.cat([gt_bboxes, proposals], dim=0)
-            gt_assign_self = torch.arange(
-                1,
-                len(gt_labels) + 1,
-                dtype=torch.long,
-                device=proposals.device)
-            assigned_gt_inds = torch.cat([gt_assign_self, assigned_gt_inds])
-            assigned_labels = torch.cat([gt_labels, assigned_labels])
-
-        pos_inds, neg_inds = bbox_sampling(
-            assigned_gt_inds, cfg.roi_batch_size, cfg.pos_fraction,
-            cfg.neg_pos_ub, cfg.pos_balance_sampling, max_overlaps,
-            cfg.neg_balance_thr)
-
-        pos_proposals = proposals[pos_inds]
-        neg_proposals = proposals[neg_inds]
-        pos_assigned_gt_inds = assigned_gt_inds[pos_inds] - 1
-        pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
-        pos_gt_labels = assigned_labels[pos_inds]
-
-        return (pos_proposals, neg_proposals, pos_assigned_gt_inds,
-                pos_gt_bboxes, pos_gt_labels)
-
     def forward(self, feats, rois):
-        """Extract roi features with the roi layer. If multiple feature levels
-        are used, then rois are mapped to corresponding levels according to
-        their scales.
-        """
         if len(feats) == 1:
             return self.roi_layers[0](feats[0], rois)
 
diff --git a/tools/configs/r50_fpn_frcnn_1x.py b/tools/configs/r50_fpn_frcnn_1x.py
index 09167dd09b18a5946766afaf1d5a96277fd6ceae..044c654ffa26f0e72e79eb998187f9db52b95cc6 100644
--- a/tools/configs/r50_fpn_frcnn_1x.py
+++ b/tools/configs/r50_fpn_frcnn_1x.py
@@ -25,7 +25,7 @@ model = dict(
         target_stds=[1.0, 1.0, 1.0, 1.0],
         use_sigmoid_cls=True),
     bbox_roi_extractor=dict(
-        type='SingleLevelRoI',
+        type='SingleRoIExtractor',
         roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
         out_channels=256,
         featmap_strides=[4, 8, 16, 32]),
diff --git a/tools/configs/r50_fpn_maskrcnn_1x.py b/tools/configs/r50_fpn_maskrcnn_1x.py
index 35dab5633c95b855f4d32a8245aaf44c69cddced..881a7498f624be8efab2f1387594e286cf15b3ef 100644
--- a/tools/configs/r50_fpn_maskrcnn_1x.py
+++ b/tools/configs/r50_fpn_maskrcnn_1x.py
@@ -25,7 +25,7 @@ model = dict(
         target_stds=[1.0, 1.0, 1.0, 1.0],
         use_sigmoid_cls=True),
     bbox_roi_extractor=dict(
-        type='SingleLevelRoI',
+        type='SingleRoIExtractor',
         roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
         out_channels=256,
         featmap_strides=[4, 8, 16, 32]),
@@ -40,7 +40,7 @@ model = dict(
         target_stds=[0.1, 0.1, 0.2, 0.2],
         reg_class_agnostic=False),
     mask_roi_extractor=dict(
-        type='SingleLevelRoI',
+        type='SingleRoIExtractor',
         roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
         out_channels=256,
         featmap_strides=[4, 8, 16, 32]),