diff --git a/mmdet/core/bbox/samplers/ohem_sampler.py b/mmdet/core/bbox/samplers/ohem_sampler.py index dee64f3671d5cedd9fcd19e08e6ec6102b39a8ce..800a1c292cfbc9fd80f5242d1131eeca81fd8577 100644 --- a/mmdet/core/bbox/samplers/ohem_sampler.py +++ b/mmdet/core/bbox/samplers/ohem_sampler.py @@ -15,8 +15,13 @@ class OHEMSampler(BaseSampler): **kwargs): super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub, add_gt_as_proposals) - self.bbox_roi_extractor = context.bbox_roi_extractor - self.bbox_head = context.bbox_head + if not hasattr(context, 'num_stages'): + self.bbox_roi_extractor = context.bbox_roi_extractor + self.bbox_head = context.bbox_head + else: + self.bbox_roi_extractor = context.bbox_roi_extractor[ + context.current_stage] + self.bbox_head = context.bbox_head[context.current_stage] def hard_mining(self, inds, num_expected, bboxes, labels, feats): with torch.no_grad(): diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py index 6df9e22689d6a423a6e948623b50b779c11ad0b1..d466f633d27aeb52dcc8b78b5c68939f3691253e 100644 --- a/mmdet/models/detectors/cascade_rcnn.py +++ b/mmdet/models/detectors/cascade_rcnn.py @@ -7,7 +7,7 @@ from .base import BaseDetector from .test_mixins import RPNTestMixin from .. import builder from ..registry import DETECTORS -from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply, +from mmdet.core import (build_assigner, bbox2roi, bbox2result, build_sampler, merge_aug_masks) @@ -131,17 +131,31 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): proposal_list = proposals for i in range(self.num_stages): + self.current_stage = i rcnn_train_cfg = self.train_cfg.rcnn[i] lw = self.train_cfg.stage_loss_weights[i] # assign gts and sample proposals - assign_results, sampling_results = multi_apply( - assign_and_sample, - proposal_list, - gt_bboxes, - gt_bboxes_ignore, - gt_labels, - cfg=rcnn_train_cfg) + sampling_results = [] + if self.with_bbox or self.with_mask: + bbox_assigner = build_assigner(rcnn_train_cfg.assigner) + bbox_sampler = build_sampler( + rcnn_train_cfg.sampler, context=self) + num_imgs = img.size(0) + if gt_bboxes_ignore is None: + gt_bboxes_ignore = [None for _ in range(num_imgs)] + + for j in range(num_imgs): + assign_result = bbox_assigner.assign( + proposal_list[j], gt_bboxes[j], gt_bboxes_ignore[j], + gt_labels[j]) + sampling_result = bbox_sampler.sample( + assign_result, + proposal_list[j], + gt_bboxes[j], + gt_labels[j], + feats=[lvl_feat[j][None] for lvl_feat in x]) + sampling_results.append(sampling_result) # bbox head forward and loss bbox_roi_extractor = self.bbox_roi_extractor[i]