From ee7e679afa80bf6a7ca2a635b81f2fe039288361 Mon Sep 17 00:00:00 2001 From: Mordekaiser <2601882982@qq.com> Date: Thu, 4 Apr 2019 05:42:51 +0800 Subject: [PATCH] fix OHEM with cascade rcnn (#373) * fix OHEM with cascade_rcnn * fix OHEM with cascade_rcnn * delete space * delete white space * delete unused lib * Delete cascade_rcnn_ohem_r101_fpn_1x.py * fix unreasonable code * fix Single quote * fix code style * fix code style * fix file permission --- mmdet/core/bbox/samplers/ohem_sampler.py | 9 +++++-- mmdet/models/detectors/cascade_rcnn.py | 30 +++++++++++++++++------- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/mmdet/core/bbox/samplers/ohem_sampler.py b/mmdet/core/bbox/samplers/ohem_sampler.py index dee64f3..800a1c2 100644 --- a/mmdet/core/bbox/samplers/ohem_sampler.py +++ b/mmdet/core/bbox/samplers/ohem_sampler.py @@ -15,8 +15,13 @@ class OHEMSampler(BaseSampler): **kwargs): super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub, add_gt_as_proposals) - self.bbox_roi_extractor = context.bbox_roi_extractor - self.bbox_head = context.bbox_head + if not hasattr(context, 'num_stages'): + self.bbox_roi_extractor = context.bbox_roi_extractor + self.bbox_head = context.bbox_head + else: + self.bbox_roi_extractor = context.bbox_roi_extractor[ + context.current_stage] + self.bbox_head = context.bbox_head[context.current_stage] def hard_mining(self, inds, num_expected, bboxes, labels, feats): with torch.no_grad(): diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py index 6df9e22..d466f63 100644 --- a/mmdet/models/detectors/cascade_rcnn.py +++ b/mmdet/models/detectors/cascade_rcnn.py @@ -7,7 +7,7 @@ from .base import BaseDetector from .test_mixins import RPNTestMixin from .. import builder from ..registry import DETECTORS -from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply, +from mmdet.core import (build_assigner, bbox2roi, bbox2result, build_sampler, merge_aug_masks) @@ -131,17 +131,31 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): proposal_list = proposals for i in range(self.num_stages): + self.current_stage = i rcnn_train_cfg = self.train_cfg.rcnn[i] lw = self.train_cfg.stage_loss_weights[i] # assign gts and sample proposals - assign_results, sampling_results = multi_apply( - assign_and_sample, - proposal_list, - gt_bboxes, - gt_bboxes_ignore, - gt_labels, - cfg=rcnn_train_cfg) + sampling_results = [] + if self.with_bbox or self.with_mask: + bbox_assigner = build_assigner(rcnn_train_cfg.assigner) + bbox_sampler = build_sampler( + rcnn_train_cfg.sampler, context=self) + num_imgs = img.size(0) + if gt_bboxes_ignore is None: + gt_bboxes_ignore = [None for _ in range(num_imgs)] + + for j in range(num_imgs): + assign_result = bbox_assigner.assign( + proposal_list[j], gt_bboxes[j], gt_bboxes_ignore[j], + gt_labels[j]) + sampling_result = bbox_sampler.sample( + assign_result, + proposal_list[j], + gt_bboxes[j], + gt_labels[j], + feats=[lvl_feat[j][None] for lvl_feat in x]) + sampling_results.append(sampling_result) # bbox head forward and loss bbox_roi_extractor = self.bbox_roi_extractor[i] -- GitLab