From c10ae5d2ca158ae506b93198afa4926ea48bb40b Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Sun, 21 Oct 2018 20:21:35 +0800 Subject: [PATCH] add cascade rcnn --- configs/cascade_mask_rcnn_r50_fpn_1x.py | 224 ++++++++++++++++++ configs/cascade_rcnn_r50_fpn_1x.py | 209 +++++++++++++++++ mmdet/models/bbox_heads/bbox_head.py | 71 ++++++ mmdet/models/detectors/__init__.py | 3 +- mmdet/models/detectors/cascade_rcnn.py | 291 ++++++++++++++++++++++++ 5 files changed, 797 insertions(+), 1 deletion(-) create mode 100644 configs/cascade_mask_rcnn_r50_fpn_1x.py create mode 100644 configs/cascade_rcnn_r50_fpn_1x.py create mode 100644 mmdet/models/detectors/cascade_rcnn.py diff --git a/configs/cascade_mask_rcnn_r50_fpn_1x.py b/configs/cascade_mask_rcnn_r50_fpn_1x.py new file mode 100644 index 0000000..ccda54b --- /dev/null +++ b/configs/cascade_mask_rcnn_r50_fpn_1x.py @@ -0,0 +1,224 @@ +# model settings +model = dict( + type='CascadeRCNN', + num_stages=3, + pretrained='modelzoo://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_scales=[8], + anchor_ratios=[0.5, 1.0, 2.0], + anchor_strides=[4, 8, 16, 32, 64], + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0], + use_sigmoid_cls=True), + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=[ + dict( + type='SharedFCBBoxHead', + num_fcs=2, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=81, + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2], + reg_class_agnostic=True), + dict( + type='SharedFCBBoxHead', + num_fcs=2, + in_channels=256, + 
fc_out_channels=1024, + roi_feat_size=7, + num_classes=81, + target_means=[0., 0., 0., 0.], + target_stds=[0.05, 0.05, 0.1, 0.1], + reg_class_agnostic=True), + dict( + type='SharedFCBBoxHead', + num_fcs=2, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=81, + target_means=[0., 0., 0., 0.], + target_stds=[0.033, 0.033, 0.067, 0.067], + reg_class_agnostic=True) + ], + mask_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + mask_head=dict( + type='FCNMaskHead', + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=81)) +# model training and testing settings +train_cfg = dict( + rpn=dict( + assigner=dict( + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + ignore_iof_thr=-1), + sampler=dict( + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False, + pos_balance_sampling=False, + neg_balance_thr=0), + allowed_border=0, + pos_weight=-1, + smoothl1_beta=1 / 9.0, + debug=False), + rcnn=[ + dict( + assigner=dict( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + ignore_iof_thr=-1), + sampler=dict( + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True, + pos_balance_sampling=False, + neg_balance_thr=0), + mask_size=28, + pos_weight=-1, + debug=False), + dict( + assigner=dict( + pos_iou_thr=0.6, + neg_iou_thr=0.6, + min_pos_iou=0.6, + ignore_iof_thr=-1), + sampler=dict( + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True, + pos_balance_sampling=False, + neg_balance_thr=0), + mask_size=28, + pos_weight=-1, + debug=False), + dict( + assigner=dict( + pos_iou_thr=0.7, + neg_iou_thr=0.7, + min_pos_iou=0.7, + ignore_iof_thr=-1), + sampler=dict( + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True, + pos_balance_sampling=False, + neg_balance_thr=0), + mask_size=28, + pos_weight=-1, + debug=False) + ], + loss_weight=[1, 
0.5, 0.4]) +test_cfg = dict( + rpn=dict( + nms_across_levels=False, + nms_pre=2000, + nms_post=2000, + max_num=2000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, max_per_img=100, nms_thr=0.5, mask_thr_binary=0.5), + keep_all_stages=False) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +data = dict( + imgs_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0.5, + with_mask=True, + with_crowd=True, + with_label=True), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=True, + with_crowd=True, + with_label=True), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=True, + with_label=False, + test_mode=True)) +# optimizer +optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[8, 11]) +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +# runtime settings +total_epochs = 12 +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/cascade_mask_rcnn_r50_fpn_1x' +load_from = None +resume_from = None 
+workflow = [('train', 1)] diff --git a/configs/cascade_rcnn_r50_fpn_1x.py b/configs/cascade_rcnn_r50_fpn_1x.py new file mode 100644 index 0000000..4b4fe16 --- /dev/null +++ b/configs/cascade_rcnn_r50_fpn_1x.py @@ -0,0 +1,209 @@ +# model settings +model = dict( + type='CascadeRCNN', + num_stages=3, + pretrained='modelzoo://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_scales=[8], + anchor_ratios=[0.5, 1.0, 2.0], + anchor_strides=[4, 8, 16, 32, 64], + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0], + use_sigmoid_cls=True), + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=[ + dict( + type='SharedFCBBoxHead', + num_fcs=2, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=81, + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2], + reg_class_agnostic=True), + dict( + type='SharedFCBBoxHead', + num_fcs=2, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=81, + target_means=[0., 0., 0., 0.], + target_stds=[0.05, 0.05, 0.1, 0.1], + reg_class_agnostic=True), + dict( + type='SharedFCBBoxHead', + num_fcs=2, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=81, + target_means=[0., 0., 0., 0.], + target_stds=[0.033, 0.033, 0.067, 0.067], + reg_class_agnostic=True) + ]) +# model training and testing settings +train_cfg = dict( + rpn=dict( + assigner=dict( + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + ignore_iof_thr=-1), + sampler=dict( + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False, + pos_balance_sampling=False, + 
neg_balance_thr=0), + allowed_border=0, + pos_weight=-1, + smoothl1_beta=1 / 9.0, + debug=False), + rcnn=[ + dict( + assigner=dict( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + ignore_iof_thr=-1), + sampler=dict( + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True, + pos_balance_sampling=False, + neg_balance_thr=0), + pos_weight=-1, + debug=False), + dict( + assigner=dict( + pos_iou_thr=0.6, + neg_iou_thr=0.6, + min_pos_iou=0.6, + ignore_iof_thr=-1), + sampler=dict( + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True, + pos_balance_sampling=False, + neg_balance_thr=0), + pos_weight=-1, + debug=False), + dict( + assigner=dict( + pos_iou_thr=0.7, + neg_iou_thr=0.7, + min_pos_iou=0.7, + ignore_iof_thr=-1), + sampler=dict( + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True, + pos_balance_sampling=False, + neg_balance_thr=0), + pos_weight=-1, + debug=False) + ], + loss_weight=[1, 0.5, 0.4]) +test_cfg = dict( + rpn=dict( + nms_across_levels=False, + nms_pre=2000, + nms_post=2000, + max_num=2000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5), + keep_all_stages=False) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +data = dict( + imgs_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0.5, + with_mask=False, + with_crowd=True, + with_label=True), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=False, + with_crowd=True, + 
with_label=True), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=False, + with_label=False, + test_mode=True)) +# optimizer +optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[8, 11]) +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +# runtime settings +total_epochs = 12 +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/cascade_rcnn_r50_fpn_1x' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py index 9b423bd..a1dda4d 100644 --- a/mmdet/models/bbox_heads/bbox_head.py +++ b/mmdet/models/bbox_heads/bbox_head.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn import torch.nn.functional as F @@ -122,3 +123,73 @@ class BBoxHead(nn.Module): nms_cfg.max_per_img) return det_bboxes, det_labels + + def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas): + """Refine bboxes during training. + + Args: + rois (Tensor): Shape (n*bs, 5), where n is image number per GPU, + and bs is the sampled RoIs per image. + labels (Tensor): Shape (n*bs, ). + bbox_preds (Tensor): Shape (n*bs, 4) or (n*bs, 4*#class). + pos_is_gts (list[Tensor]): Flags indicating if each positive bbox + is a gt bbox. + img_metas (list[dict]): Meta info of each image. + + Returns: + list[Tensor]: Refined bboxes of each image in a mini-batch. 
+ """ + img_ids = rois[:, 0].long().unique(sorted=True) + assert img_ids.numel() == len(img_metas) + + bboxes_list = [] + for i in range(len(img_metas)): + inds = torch.nonzero(rois[:, 0] == i).squeeze() + num_rois = inds.numel() + + bboxes_ = rois[inds, 1:] + label_ = labels[inds] + bbox_pred_ = bbox_preds[inds] + img_meta_ = img_metas[i] + pos_is_gts_ = pos_is_gts[i] + + bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_, + img_meta_) + # filter gt bboxes + pos_keep = 1 - pos_is_gts_ + keep_inds = pos_is_gts_.new_ones(num_rois) + keep_inds[:len(pos_is_gts_)] = pos_keep + + bboxes_list.append(bboxes[keep_inds]) + + return bboxes_list + + def regress_by_class(self, rois, label, bbox_pred, img_meta): + """Regress the bbox for the predicted class. Used in Cascade R-CNN. + + Args: + rois (Tensor): shape (n, 4) or (n, 5) + label (Tensor): shape (n, ) + bbox_pred (Tensor): shape (n, 4*(#class+1)) or (n, 4) + img_shape (Tensor): shape (3, ) + + Returns: + Tensor: same shape as input rois + """ + assert rois.size(1) == 4 or rois.size(1) == 5 + + if not self.reg_class_agnostic: + label = label * 4 + inds = torch.stack((label, label + 1, label + 2, label + 3), 1) + bbox_pred = torch.gather(bbox_pred, 1, inds) + assert bbox_pred.size(1) == 4 + + if rois.size(1) == 4: + new_rois = delta2bbox(rois, bbox_pred, self.target_means, + self.target_stds, img_meta['img_shape']) + else: + bboxes = delta2bbox(rois[:, 1:], bbox_pred, self.target_means, + self.target_stds, img_meta['img_shape']) + new_rois = torch.cat((rois[:, [0]], bboxes), dim=1) + + return new_rois diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py index a784d5f..7cd084f 100644 --- a/mmdet/models/detectors/__init__.py +++ b/mmdet/models/detectors/__init__.py @@ -4,8 +4,9 @@ from .rpn import RPN from .fast_rcnn import FastRCNN from .faster_rcnn import FasterRCNN from .mask_rcnn import MaskRCNN +from .cascade_rcnn import CascadeRCNN __all__ = [ 'BaseDetector', 'TwoStageDetector', 
'RPN', 'FastRCNN', 'FasterRCNN', - 'MaskRCNN' + 'MaskRCNN', 'CascadeRCNN' ] diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py new file mode 100644 index 0000000..4da055a --- /dev/null +++ b/mmdet/models/detectors/cascade_rcnn.py @@ -0,0 +1,291 @@ +import torch +import torch.nn as nn + +from .base import BaseDetector +from .test_mixins import RPNTestMixin +from .. import builder +from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply, + merge_aug_masks) + + +class CascadeRCNN(BaseDetector, RPNTestMixin): + + def __init__(self, + num_stages, + backbone, + neck=None, + rpn_head=None, + bbox_roi_extractor=None, + bbox_head=None, + mask_roi_extractor=None, + mask_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None): + assert bbox_roi_extractor is not None + assert bbox_head is not None + super(CascadeRCNN, self).__init__() + + self.num_stages = num_stages + self.backbone = builder.build_backbone(backbone) + + if neck is not None: + self.neck = builder.build_neck(neck) + else: + raise NotImplementedError + + if rpn_head is not None: + self.rpn_head = builder.build_rpn_head(rpn_head) + + if bbox_head is not None: + self.bbox_roi_extractor = nn.ModuleList() + self.bbox_head = nn.ModuleList() + if not isinstance(bbox_roi_extractor, list): + bbox_roi_extractor = [ + bbox_roi_extractor for _ in range(num_stages) + ] + if not isinstance(bbox_head, list): + bbox_head = [bbox_head for _ in range(num_stages)] + assert len(bbox_roi_extractor) == len(bbox_head) == self.num_stages + for roi_extractor, head in zip(bbox_roi_extractor, bbox_head): + self.bbox_roi_extractor.append( + builder.build_roi_extractor(roi_extractor)) + self.bbox_head.append(builder.build_bbox_head(head)) + + if mask_head is not None: + self.mask_roi_extractor = nn.ModuleList() + self.mask_head = nn.ModuleList() + if not isinstance(mask_roi_extractor, list): + mask_roi_extractor = [ + mask_roi_extractor for _ in range(num_stages) + ] 
+ if not isinstance(mask_head, list): + mask_head = [mask_head for _ in range(num_stages)] + assert len(mask_roi_extractor) == len(mask_head) == self.num_stages + for roi_extractor, head in zip(mask_roi_extractor, mask_head): + self.mask_roi_extractor.append( + builder.build_roi_extractor(roi_extractor)) + self.mask_head.append(builder.build_mask_head(head)) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + self.init_weights(pretrained=pretrained) + + @property + def with_rpn(self): + return hasattr(self, 'rpn_head') and self.rpn_head is not None + + def init_weights(self, pretrained=None): + super(CascadeRCNN, self).init_weights(pretrained) + self.backbone.init_weights(pretrained=pretrained) + if self.with_neck: + if isinstance(self.neck, nn.Sequential): + for m in self.neck: + m.init_weights() + else: + self.neck.init_weights() + if self.with_rpn: + self.rpn_head.init_weights() + for i in range(self.num_stages): + if self.with_bbox: + self.bbox_roi_extractor[i].init_weights() + self.bbox_head[i].init_weights() + if self.with_mask: + self.mask_roi_extractor[i].init_weights() + self.mask_head[i].init_weights() + + def extract_feat(self, img): + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def forward_train(self, + img, + img_meta, + gt_bboxes, + gt_bboxes_ignore, + gt_labels, + gt_masks=None, + proposals=None): + x = self.extract_feat(img) + + losses = dict() + + if self.with_rpn: + rpn_outs = self.rpn_head(x) + rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, + self.train_cfg.rpn) + rpn_losses = self.rpn_head.loss(*rpn_loss_inputs) + losses.update(rpn_losses) + + proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn) + proposal_list = self.rpn_head.get_proposals(*proposal_inputs) + else: + proposal_list = proposals + + for i in range(self.num_stages): + rcnn_train_cfg = self.train_cfg.rcnn[i] + lw = self.train_cfg.loss_weight[i] + + # assign gts and sample proposals + assign_results, sampling_results = 
multi_apply( + assign_and_sample, + proposal_list, + gt_bboxes, + gt_bboxes_ignore, + gt_labels, + cfg=rcnn_train_cfg) + + # bbox head forward and loss + bbox_roi_extractor = self.bbox_roi_extractor[i] + bbox_head = self.bbox_head[i] + + rois = bbox2roi([res.bboxes for res in sampling_results]) + bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs], + rois) + cls_score, bbox_pred = bbox_head(bbox_feats) + + bbox_targets = bbox_head.get_target(sampling_results, gt_bboxes, + gt_labels, rcnn_train_cfg) + loss_bbox = bbox_head.loss(cls_score, bbox_pred, *bbox_targets) + for name, value in loss_bbox.items(): + losses['s{}.{}'.format(i, name)] = (value * lw if + 'loss' in name else value) + + # mask head forward and loss + if self.with_mask: + mask_roi_extractor = self.mask_roi_extractor[i] + mask_head = self.mask_head[i] + pos_rois = bbox2roi( + [res.pos_bboxes for res in sampling_results]) + mask_feats = mask_roi_extractor( + x[:mask_roi_extractor.num_inputs], pos_rois) + mask_pred = mask_head(mask_feats) + mask_targets = mask_head.get_target(sampling_results, gt_masks, + rcnn_train_cfg) + pos_labels = torch.cat( + [res.pos_gt_labels for res in sampling_results]) + loss_mask = mask_head.loss(mask_pred, mask_targets, pos_labels) + for name, value in loss_mask.items(): + losses['s{}.{}'.format(i, name)] = (value * lw + if 'loss' in name else + value) + + # refine bboxes + if i < self.num_stages - 1: + pos_is_gts = [res.pos_is_gt for res in sampling_results] + roi_labels = bbox_targets[0] # bbox_targets is a tuple + with torch.no_grad(): + proposal_list = bbox_head.refine_bboxes( + rois, roi_labels, bbox_pred, pos_is_gts, img_meta) + + return losses + + def simple_test(self, img, img_meta, proposals=None, rescale=False): + x = self.extract_feat(img) + proposal_list = self.simple_test_rpn( + x, img_meta, self.test_cfg.rpn) if proposals is None else proposals + + img_shape = img_meta[0]['img_shape'] + ori_shape = img_meta[0]['ori_shape'] + scale_factor = 
img_meta[0]['scale_factor']
+
+        # "ms" in variable names means multi-stage
+        ms_bbox_result = []
+        ms_segm_result = []
+        ms_scores = []
+        rcnn_test_cfg = self.test_cfg.rcnn
+
+        rois = bbox2roi(proposal_list)
+        for i in range(self.num_stages):
+            bbox_roi_extractor = self.bbox_roi_extractor[i]
+            bbox_head = self.bbox_head[i]
+
+            bbox_feats = bbox_roi_extractor(
+                x[:len(bbox_roi_extractor.featmap_strides)], rois)
+            cls_score, bbox_pred = bbox_head(bbox_feats)
+            ms_scores.append(cls_score)
+
+            if self.test_cfg.keep_all_stages:
+                det_bboxes, det_labels = bbox_head.get_det_bboxes(
+                    rois,
+                    cls_score,
+                    bbox_pred,
+                    img_shape,
+                    scale_factor,
+                    rescale=rescale,
+                    nms_cfg=rcnn_test_cfg)
+                bbox_result = bbox2result(det_bboxes, det_labels,
+                                          bbox_head.num_classes)
+                ms_bbox_result.append(bbox_result)
+
+                if self.with_mask:
+                    # NOTE(review): fixed undefined attributes
+                    # `self.mask_blocks` / `self.mask_heads`; __init__ defines
+                    # `self.mask_roi_extractor` and `self.mask_head`.
+                    mask_roi_extractor = self.mask_roi_extractor[i]
+                    mask_head = self.mask_head[i]
+                    if det_bboxes.shape[0] == 0:
+                        segm_result = [
+                            [] for _ in range(mask_head.num_classes - 1)
+                        ]
+                    else:
+                        # NOTE(review): was `* img_shape[-1]` (the channel
+                        # count, 3); mapping boxes back to the input scale
+                        # requires `scale_factor`.
+                        _bboxes = (det_bboxes[:, :4] * scale_factor
+                                   if rescale else det_bboxes)
+                        mask_rois = bbox2roi([_bboxes])
+                        mask_feats = mask_roi_extractor(
+                            x[:len(mask_roi_extractor.featmap_strides)],
+                            mask_rois)
+                        mask_pred = mask_head(mask_feats)
+                        segm_result = mask_head.get_seg_masks(
+                            mask_pred, _bboxes, det_labels, rcnn_test_cfg,
+                            ori_shape, scale_factor, rescale)
+                    ms_segm_result.append(segm_result)
+
+            if i < self.num_stages - 1:
+                bbox_label = cls_score.argmax(dim=1)
+                rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred,
+                                                  img_meta[0])
+
+        cls_score = sum(ms_scores) / float(len(ms_scores))
+        det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes(
+            rois,
+            cls_score,
+            bbox_pred,
+            img_shape,
+            scale_factor,
+            rescale=rescale,
+            nms_cfg=rcnn_test_cfg)
+        bbox_result = bbox2result(det_bboxes, det_labels,
+                                  self.bbox_head[-1].num_classes)
+        ms_bbox_result.append(bbox_result)
+
+        if self.with_mask:
+            # NOTE(review): `_bboxes` / `mask_rois` were previously only
+            # defined inside the keep_all_stages branch, so the default
+            # keep_all_stages=False config raised NameError here; compute
+            # them for this branch and guard the empty-detection case.
+            if det_bboxes.shape[0] == 0:
+                segm_result = [
+                    [] for _ in range(self.mask_head[-1].num_classes - 1)
+                ]
+            else:
+                _bboxes = (det_bboxes[:, :4] * scale_factor
+                           if rescale else det_bboxes)
+                mask_rois = bbox2roi([_bboxes])
+                aug_masks = []
+                for i in range(self.num_stages):
+                    mask_roi_extractor = self.mask_roi_extractor[i]
+                    mask_feats = mask_roi_extractor(
+                        x[:len(mask_roi_extractor.featmap_strides)],
+                        mask_rois)
+                    mask_pred = self.mask_head[i](mask_feats)
+                    aug_masks.append(mask_pred.sigmoid().cpu().numpy())
+                merged_masks = merge_aug_masks(
+                    aug_masks, [img_meta[0]] * self.num_stages,
+                    self.test_cfg.rcnn)
+                segm_result = self.mask_head[-1].get_seg_masks(
+                    merged_masks, _bboxes, det_labels, rcnn_test_cfg,
+                    ori_shape, scale_factor, rescale)
+            ms_segm_result.append(segm_result)
+
+        if not self.with_mask:
+            return ms_bbox_result
+        else:
+            return ms_bbox_result, ms_segm_result
+
+    def aug_test(self, img, img_meta, proposals=None, rescale=False):
+        raise NotImplementedError
+
+    def show_result(self, data, result, img_norm_cfg, **kwargs):
+        # TODO: show segmentation masks
+        if self.with_mask:
+            ms_bbox_result, ms_segm_result = result
+        else:
+            ms_bbox_result = result
+        super(CascadeRCNN, self).show_result(data, ms_bbox_result[-1],
+                                             img_norm_cfg, **kwargs)
-- GitLab