From af55b977619d6b93aa844a5152b12fe441a8b94d Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Fri, 23 Nov 2018 23:01:33 +0800 Subject: [PATCH] use dict to save multi-stage results --- configs/cascade_mask_rcnn_r50_fpn_1x.py | 2 +- configs/cascade_rcnn_r50_fpn_1x.py | 2 +- mmdet/models/detectors/cascade_rcnn.py | 39 +++++++++++++++---------- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/configs/cascade_mask_rcnn_r50_fpn_1x.py b/configs/cascade_mask_rcnn_r50_fpn_1x.py index ccda54b..9f3f8b6 100644 --- a/configs/cascade_mask_rcnn_r50_fpn_1x.py +++ b/configs/cascade_mask_rcnn_r50_fpn_1x.py @@ -142,7 +142,7 @@ train_cfg = dict( pos_weight=-1, debug=False) ], - loss_weight=[1, 0.5, 0.4]) + stage_loss_weights=[1, 0.5, 0.25]) test_cfg = dict( rpn=dict( nms_across_levels=False, diff --git a/configs/cascade_rcnn_r50_fpn_1x.py b/configs/cascade_rcnn_r50_fpn_1x.py index 4b4fe16..5b4a70c 100644 --- a/configs/cascade_rcnn_r50_fpn_1x.py +++ b/configs/cascade_rcnn_r50_fpn_1x.py @@ -128,7 +128,7 @@ train_cfg = dict( pos_weight=-1, debug=False) ], - loss_weight=[1, 0.5, 0.4]) + stage_loss_weights=[1, 0.5, 0.25]) test_cfg = dict( rpn=dict( nms_across_levels=False, diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py index 3d4a712..16ad5fe 100644 --- a/mmdet/models/detectors/cascade_rcnn.py +++ b/mmdet/models/detectors/cascade_rcnn.py @@ -1,3 +1,5 @@ +from __future__ import division + import torch import torch.nn as nn @@ -127,7 +129,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): for i in range(self.num_stages): rcnn_train_cfg = self.train_cfg.rcnn[i] - lw = self.train_cfg.loss_weight[i] + lw = self.train_cfg.stage_loss_weights[i] # assign gts and sample proposals assign_results, sampling_results = multi_apply( @@ -193,8 +195,8 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): scale_factor = img_meta[0]['scale_factor'] # "ms" in variable names means multi-stage - ms_bbox_result = [] - ms_segm_result = [] + ms_bbox_result = {} + ms_segm_result = {} ms_scores = [] rcnn_test_cfg = self.test_cfg.rcnn @@ -219,11 +221,11 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): nms_cfg=rcnn_test_cfg) bbox_result = bbox2result(det_bboxes, det_labels, bbox_head.num_classes) - ms_bbox_result.append(bbox_result) + ms_bbox_result['stage{}'.format(i)] = bbox_result if self.with_mask: - mask_block = self.mask_blocks[i] - mask_head = self.mask_heads[i] + mask_roi_extractor = self.mask_roi_extractor[i] + mask_head = self.mask_head[i] if det_bboxes.shape[0] == 0: segm_result = [ [] for _ in range(mask_head.num_classes - 1) @@ -232,20 +234,21 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): _bboxes = (det_bboxes[:, :4] * scale_factor if rescale else det_bboxes) mask_rois = bbox2roi([_bboxes]) - mask_feats = mask_block( - x[:len(mask_block.featmap_strides)], mask_rois) + mask_feats = mask_roi_extractor( + x[:len(mask_roi_extractor.featmap_strides)], + mask_rois) mask_pred = mask_head(mask_feats) segm_result = mask_head.get_seg_masks( mask_pred, _bboxes, det_labels, rcnn_test_cfg, ori_shape, scale_factor, rescale) - ms_segm_result.append(segm_result) + ms_segm_result['stage{}'.format(i)] = segm_result if i < self.num_stages - 1: bbox_label = cls_score.argmax(dim=1) rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred, img_meta[0]) - cls_score = sum(ms_scores) / float(len(ms_scores)) + cls_score = sum(ms_scores) / len(ms_scores) det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes( rois, cls_score, @@ -256,7 +259,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): nms_cfg=rcnn_test_cfg) bbox_result = bbox2result(det_bboxes, det_labels, self.bbox_head[-1].num_classes) - ms_bbox_result.append(bbox_result) + ms_bbox_result['ensemble'] = bbox_result if self.with_mask: if det_bboxes.shape[0] == 0: @@ -280,12 +283,12 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): segm_result = self.mask_head[-1].get_seg_masks( merged_masks, _bboxes, det_labels, rcnn_test_cfg, ori_shape, scale_factor, rescale) - ms_segm_result.append(segm_result) + ms_segm_result['ensemble'] = segm_result if not self.test_cfg.keep_all_stages: - ms_bbox_result = ms_bbox_result[0] + ms_bbox_result = ms_bbox_result['ensemble'] if self.with_mask: - ms_segm_result = ms_segm_result[0] + ms_segm_result = ms_segm_result['ensemble'] if not self.with_mask: return ms_bbox_result @@ -301,5 +304,9 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ms_bbox_result, ms_segm_result = result else: ms_bbox_result = result - super(CascadeRCNN, self).show_result(data, ms_bbox_result[-1], - img_norm_cfg, **kwargs) + if isinstance(ms_bbox_result, dict): + bbox_result = ms_bbox_result['ensemble'] + else: + bbox_result = ms_bbox_result + super(CascadeRCNN, self).show_result(data, bbox_result, img_norm_cfg, + **kwargs) -- GitLab