From d7743255f24d4be55a6eacc56103063d9046fe12 Mon Sep 17 00:00:00 2001 From: pangjm <pjmzju@gmail.com> Date: Wed, 10 Oct 2018 14:47:08 +0800 Subject: [PATCH] revise fast test & fix aug test bug --- mmdet/models/detectors/fast_rcnn.py | 39 ++++++++++++++++++++------- mmdet/models/detectors/test_mixins.py | 9 +++++-- mmdet/models/detectors/two_stage.py | 2 +- 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/mmdet/models/detectors/fast_rcnn.py b/mmdet/models/detectors/fast_rcnn.py index 0dbf17a..fd80a87 100644 --- a/mmdet/models/detectors/fast_rcnn.py +++ b/mmdet/models/detectors/fast_rcnn.py @@ -14,12 +14,33 @@ class FastRCNN(TwoStageDetector): mask_head=None, pretrained=None): super(FastRCNN, self).__init__( - backbone=backbone, - neck=neck, - bbox_roi_extractor=bbox_roi_extractor, - bbox_head=bbox_head, - train_cfg=train_cfg, - test_cfg=test_cfg, - mask_roi_extractor=mask_roi_extractor, - mask_head=mask_head, - pretrained=pretrained) + backbone=backbone, + neck=neck, + bbox_roi_extractor=bbox_roi_extractor, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + mask_roi_extractor=mask_roi_extractor, + mask_head=mask_head, + pretrained=pretrained) + + def forward_test(self, imgs, img_metas, proposals, **kwargs): + for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: + if not isinstance(var, list): + raise TypeError('{} must be a list, but got {}'.format( + name, type(var))) + + num_augs = len(imgs) + if num_augs != len(img_metas): + raise ValueError( + 'num of augmentations ({}) != num of image meta ({})'.format( + len(imgs), len(img_metas))) + # TODO: remove the restriction of imgs_per_gpu == 1 when prepared + imgs_per_gpu = imgs[0].size(0) + assert imgs_per_gpu == 1 + + if num_augs == 1: + return self.simple_test(imgs[0], img_metas[0], proposals[0], + **kwargs) + else: + return self.aug_test(imgs, img_metas, proposals, **kwargs) diff --git a/mmdet/models/detectors/test_mixins.py b/mmdet/models/detectors/test_mixins.py index 77ba244..38136f4 100644 --- a/mmdet/models/detectors/test_mixins.py +++ b/mmdet/models/detectors/test_mixins.py @@ -135,6 +135,11 @@ class MaskTestMixin(object): ori_shape = img_metas[0][0]['ori_shape'] segm_result = self.mask_head.get_seg_masks( - merged_masks, det_bboxes, det_labels, self.test_cfg.rcnn, - ori_shape) + merged_masks, + det_bboxes, + det_labels, + self.test_cfg.rcnn, + ori_shape, + scale_factor=1.0, + rescale=False) return segm_result diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py index 3cd6838..b2f2839 100644 --- a/mmdet/models/detectors/two_stage.py +++ b/mmdet/models/detectors/two_stage.py @@ -147,7 +147,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, proposal_list = self.simple_test_rpn( x, img_meta, - self.test_cfg.rpn) if proposals is None else proposals[0] + self.test_cfg.rpn) if proposals is None else proposals det_bboxes, det_labels = self.simple_test_bboxes( x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale) -- GitLab