diff --git a/mmdet/core/bbox/samplers/ohem_sampler.py b/mmdet/core/bbox/samplers/ohem_sampler.py
index dee64f3671d5cedd9fcd19e08e6ec6102b39a8ce..800a1c292cfbc9fd80f5242d1131eeca81fd8577 100644
--- a/mmdet/core/bbox/samplers/ohem_sampler.py
+++ b/mmdet/core/bbox/samplers/ohem_sampler.py
@@ -15,8 +15,13 @@ class OHEMSampler(BaseSampler):
                  **kwargs):
         super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
                                           add_gt_as_proposals)
-        self.bbox_roi_extractor = context.bbox_roi_extractor
-        self.bbox_head = context.bbox_head
+        if not hasattr(context, 'num_stages'):
+            self.bbox_roi_extractor = context.bbox_roi_extractor
+            self.bbox_head = context.bbox_head
+        else:
+            self.bbox_roi_extractor = context.bbox_roi_extractor[
+                context.current_stage]
+            self.bbox_head = context.bbox_head[context.current_stage]
 
     def hard_mining(self, inds, num_expected, bboxes, labels, feats):
         with torch.no_grad():
diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py
index 6df9e22689d6a423a6e948623b50b779c11ad0b1..d466f633d27aeb52dcc8b78b5c68939f3691253e 100644
--- a/mmdet/models/detectors/cascade_rcnn.py
+++ b/mmdet/models/detectors/cascade_rcnn.py
@@ -7,7 +7,7 @@ from .base import BaseDetector
 from .test_mixins import RPNTestMixin
 from .. import builder
 from ..registry import DETECTORS
-from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply,
+from mmdet.core import (build_assigner, bbox2roi, bbox2result, build_sampler,
                         merge_aug_masks)
 
 
@@ -131,17 +131,31 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
             proposal_list = proposals
 
         for i in range(self.num_stages):
+            self.current_stage = i
             rcnn_train_cfg = self.train_cfg.rcnn[i]
             lw = self.train_cfg.stage_loss_weights[i]
 
             # assign gts and sample proposals
-            assign_results, sampling_results = multi_apply(
-                assign_and_sample,
-                proposal_list,
-                gt_bboxes,
-                gt_bboxes_ignore,
-                gt_labels,
-                cfg=rcnn_train_cfg)
+            sampling_results = []
+            if self.with_bbox or self.with_mask:
+                bbox_assigner = build_assigner(rcnn_train_cfg.assigner)
+                bbox_sampler = build_sampler(
+                    rcnn_train_cfg.sampler, context=self)
+                num_imgs = img.size(0)
+                if gt_bboxes_ignore is None:
+                    gt_bboxes_ignore = [None for _ in range(num_imgs)]
+
+                for j in range(num_imgs):
+                    assign_result = bbox_assigner.assign(
+                        proposal_list[j], gt_bboxes[j], gt_bboxes_ignore[j],
+                        gt_labels[j])
+                    sampling_result = bbox_sampler.sample(
+                        assign_result,
+                        proposal_list[j],
+                        gt_bboxes[j],
+                        gt_labels[j],
+                        feats=[lvl_feat[j][None] for lvl_feat in x])
+                    sampling_results.append(sampling_result)
 
             # bbox head forward and loss
             bbox_roi_extractor = self.bbox_roi_extractor[i]