From c10ae5d2ca158ae506b93198afa4926ea48bb40b Mon Sep 17 00:00:00 2001
From: Kai Chen <chenkaidev@gmail.com>
Date: Sun, 21 Oct 2018 20:21:35 +0800
Subject: [PATCH] Add Cascade R-CNN

---
 configs/cascade_mask_rcnn_r50_fpn_1x.py | 224 ++++++++++++++++++
 configs/cascade_rcnn_r50_fpn_1x.py      | 209 +++++++++++++++++
 mmdet/models/bbox_heads/bbox_head.py    |  71 ++++++
 mmdet/models/detectors/__init__.py      |   3 +-
 mmdet/models/detectors/cascade_rcnn.py  | 291 ++++++++++++++++++++++++
 5 files changed, 797 insertions(+), 1 deletion(-)
 create mode 100644 configs/cascade_mask_rcnn_r50_fpn_1x.py
 create mode 100644 configs/cascade_rcnn_r50_fpn_1x.py
 create mode 100644 mmdet/models/detectors/cascade_rcnn.py

diff --git a/configs/cascade_mask_rcnn_r50_fpn_1x.py b/configs/cascade_mask_rcnn_r50_fpn_1x.py
new file mode 100644
index 0000000..ccda54b
--- /dev/null
+++ b/configs/cascade_mask_rcnn_r50_fpn_1x.py
@@ -0,0 +1,224 @@
+# model settings
+model = dict(
+    type='CascadeRCNN',
+    num_stages=3,
+    pretrained='modelzoo://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        use_sigmoid_cls=True),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
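+    # One bbox head per cascade stage. The regression target_stds shrink
+    # stage by stage (0.1 -> 0.05 -> 0.033) because each stage refines boxes
+    # that are already closer to the ground truth.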
+    bbox_head=[
+        dict(
+            type='SharedFCBBoxHead',
+            num_fcs=2,
+            in_channels=256,
+            fc_out_channels=1024,
+            roi_feat_size=7,
+            num_classes=81,
+            target_means=[0., 0., 0., 0.],
+            target_stds=[0.1, 0.1, 0.2, 0.2],
+            reg_class_agnostic=True),
+        dict(
+            type='SharedFCBBoxHead',
+            num_fcs=2,
+            in_channels=256,
+            fc_out_channels=1024,
+            roi_feat_size=7,
+            num_classes=81,
+            target_means=[0., 0., 0., 0.],
+            target_stds=[0.05, 0.05, 0.1, 0.1],
+            reg_class_agnostic=True),
+        dict(
+            type='SharedFCBBoxHead',
+            num_fcs=2,
+            in_channels=256,
+            fc_out_channels=1024,
+            roi_feat_size=7,
+            num_classes=81,
+            target_means=[0., 0., 0., 0.],
+            target_stds=[0.033, 0.033, 0.067, 0.067],
+            reg_class_agnostic=True)
+    ],
+    mask_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    mask_head=dict(
+        type='FCNMaskHead',
+        num_convs=4,
+        in_channels=256,
+        conv_out_channels=256,
+        num_classes=81))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False,
+            pos_balance_sampling=False,
+            neg_balance_thr=0),
+        allowed_border=0,
+        pos_weight=-1,
+        smoothl1_beta=1 / 9.0,
+        debug=False),
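+    # One rcnn training config per stage. The assigner IoU threshold rises
+    # (0.5 -> 0.6 -> 0.7) so each stage trains on progressively
+    # better-localized proposals, which is the core idea of Cascade R-CNN.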
+    rcnn=[
+        dict(
+            assigner=dict(
+                pos_iou_thr=0.5,
+                neg_iou_thr=0.5,
+                min_pos_iou=0.5,
+                ignore_iof_thr=-1),
+            sampler=dict(
+                num=512,
+                pos_fraction=0.25,
+                neg_pos_ub=-1,
+                add_gt_as_proposals=True,
+                pos_balance_sampling=False,
+                neg_balance_thr=0),
+            mask_size=28,
+            pos_weight=-1,
+            debug=False),
+        dict(
+            assigner=dict(
+                pos_iou_thr=0.6,
+                neg_iou_thr=0.6,
+                min_pos_iou=0.6,
+                ignore_iof_thr=-1),
+            sampler=dict(
+                num=512,
+                pos_fraction=0.25,
+                neg_pos_ub=-1,
+                add_gt_as_proposals=True,
+                pos_balance_sampling=False,
+                neg_balance_thr=0),
+            mask_size=28,
+            pos_weight=-1,
+            debug=False),
+        dict(
+            assigner=dict(
+                pos_iou_thr=0.7,
+                neg_iou_thr=0.7,
+                min_pos_iou=0.7,
+                ignore_iof_thr=-1),
+            sampler=dict(
+                num=512,
+                pos_fraction=0.25,
+                neg_pos_ub=-1,
+                add_gt_as_proposals=True,
+                pos_balance_sampling=False,
+                neg_balance_thr=0),
+            mask_size=28,
+            pos_weight=-1,
+            debug=False)
+    ],
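+    # Per-stage loss weights; later stages contribute less to the total loss.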
+    loss_weight=[1, 0.5, 0.4])
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05, max_per_img=100, nms_thr=0.5, mask_thr_binary=0.5),
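+    # If True, detection results from every stage are returned; otherwise
+    # only the ensembled result of the final stage is kept.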
+    keep_all_stages=False)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=True,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/cascade_mask_rcnn_r50_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/cascade_rcnn_r50_fpn_1x.py b/configs/cascade_rcnn_r50_fpn_1x.py
new file mode 100644
index 0000000..4b4fe16
--- /dev/null
+++ b/configs/cascade_rcnn_r50_fpn_1x.py
@@ -0,0 +1,209 @@
+# model settings
+model = dict(
+    type='CascadeRCNN',
+    num_stages=3,
+    pretrained='modelzoo://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        use_sigmoid_cls=True),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
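+    # One bbox head per cascade stage, with target_stds shrinking at each
+    # stage as the boxes become better localized.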
+    bbox_head=[
+        dict(
+            type='SharedFCBBoxHead',
+            num_fcs=2,
+            in_channels=256,
+            fc_out_channels=1024,
+            roi_feat_size=7,
+            num_classes=81,
+            target_means=[0., 0., 0., 0.],
+            target_stds=[0.1, 0.1, 0.2, 0.2],
+            reg_class_agnostic=True),
+        dict(
+            type='SharedFCBBoxHead',
+            num_fcs=2,
+            in_channels=256,
+            fc_out_channels=1024,
+            roi_feat_size=7,
+            num_classes=81,
+            target_means=[0., 0., 0., 0.],
+            target_stds=[0.05, 0.05, 0.1, 0.1],
+            reg_class_agnostic=True),
+        dict(
+            type='SharedFCBBoxHead',
+            num_fcs=2,
+            in_channels=256,
+            fc_out_channels=1024,
+            roi_feat_size=7,
+            num_classes=81,
+            target_means=[0., 0., 0., 0.],
+            target_stds=[0.033, 0.033, 0.067, 0.067],
+            reg_class_agnostic=True)
+    ])
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False,
+            pos_balance_sampling=False,
+            neg_balance_thr=0),
+        allowed_border=0,
+        pos_weight=-1,
+        smoothl1_beta=1 / 9.0,
+        debug=False),
+    rcnn=[
+        dict(
+            assigner=dict(
+                pos_iou_thr=0.5,
+                neg_iou_thr=0.5,
+                min_pos_iou=0.5,
+                ignore_iof_thr=-1),
+            sampler=dict(
+                num=512,
+                pos_fraction=0.25,
+                neg_pos_ub=-1,
+                add_gt_as_proposals=True,
+                pos_balance_sampling=False,
+                neg_balance_thr=0),
+            pos_weight=-1,
+            debug=False),
+        dict(
+            assigner=dict(
+                pos_iou_thr=0.6,
+                neg_iou_thr=0.6,
+                min_pos_iou=0.6,
+                ignore_iof_thr=-1),
+            sampler=dict(
+                num=512,
+                pos_fraction=0.25,
+                neg_pos_ub=-1,
+                add_gt_as_proposals=True,
+                pos_balance_sampling=False,
+                neg_balance_thr=0),
+            pos_weight=-1,
+            debug=False),
+        dict(
+            assigner=dict(
+                pos_iou_thr=0.7,
+                neg_iou_thr=0.7,
+                min_pos_iou=0.7,
+                ignore_iof_thr=-1),
+            sampler=dict(
+                num=512,
+                pos_fraction=0.25,
+                neg_pos_ub=-1,
+                add_gt_as_proposals=True,
+                pos_balance_sampling=False,
+                neg_balance_thr=0),
+            pos_weight=-1,
+            debug=False)
+    ],
+    loss_weight=[1, 0.5, 0.4])
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5),
+    keep_all_stages=False)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/cascade_rcnn_r50_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py
index 9b423bd..a1dda4d 100644
--- a/mmdet/models/bbox_heads/bbox_head.py
+++ b/mmdet/models/bbox_heads/bbox_head.py
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -122,3 +123,73 @@ class BBoxHead(nn.Module):
                 nms_cfg.max_per_img)
 
             return det_bboxes, det_labels
+
+    def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
+        """Refine bboxes during training.
+
+        Args:
+            rois (Tensor): Shape (n*bs, 5), where n is the number of images
+                per GPU and bs is the number of RoIs sampled per image.
+            labels (Tensor): Shape (n*bs, ).
+            bbox_preds (Tensor): Shape (n*bs, 4) or (n*bs, 4*#class).
+            pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
+                is a gt bbox.
+            img_metas (list[dict]): Meta info of each image.
+
+        Returns:
+            list[Tensor]: Refined bboxes of each image in a mini-batch.
+        """
+        img_ids = rois[:, 0].long().unique(sorted=True)
+        assert img_ids.numel() == len(img_metas)
+
+        bboxes_list = []
+        for i in range(len(img_metas)):
+            inds = torch.nonzero(rois[:, 0] == i).squeeze()
+            num_rois = inds.numel()
+
+            bboxes_ = rois[inds, 1:]
+            label_ = labels[inds]
+            bbox_pred_ = bbox_preds[inds]
+            img_meta_ = img_metas[i]
+            pos_is_gts_ = pos_is_gts[i]
+
+            bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
+                                           img_meta_)
+            # filter gt bboxes
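+            # Sampled RoIs list positives first, and gt boxes were added as
+            # proposals during sampling; drop them here so the next stage is
+            # not trained on perfect boxes.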
+            pos_keep = 1 - pos_is_gts_
+            keep_inds = pos_is_gts_.new_ones(num_rois)
+            keep_inds[:len(pos_is_gts_)] = pos_keep
+
+            bboxes_list.append(bboxes[keep_inds])
+
+        return bboxes_list
+
+    def regress_by_class(self, rois, label, bbox_pred, img_meta):
+        """Regress the bbox for the predicted class. Used in Cascade R-CNN.
+
+        Args:
+            rois (Tensor): shape (n, 4) or (n, 5)
+            label (Tensor): shape (n, )
+            bbox_pred (Tensor): shape (n, 4*(#class+1)) or (n, 4)
+            img_meta (dict): Image meta info containing 'img_shape'.
+
+        Returns:
+            Tensor: same shape as input rois
+        """
+        assert rois.size(1) == 4 or rois.size(1) == 5
+
+        if not self.reg_class_agnostic:
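+            # bbox_pred stores 4 deltas per class; gather the 4 columns
+            # belonging to each RoI's predicted class.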
+            label = label * 4
+            inds = torch.stack((label, label + 1, label + 2, label + 3), 1)
+            bbox_pred = torch.gather(bbox_pred, 1, inds)
+        assert bbox_pred.size(1) == 4
+
+        if rois.size(1) == 4:
+            new_rois = delta2bbox(rois, bbox_pred, self.target_means,
+                                  self.target_stds, img_meta['img_shape'])
+        else:
+            bboxes = delta2bbox(rois[:, 1:], bbox_pred, self.target_means,
+                                self.target_stds, img_meta['img_shape'])
+            new_rois = torch.cat((rois[:, [0]], bboxes), dim=1)
+
+        return new_rois
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index a784d5f..7cd084f 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -4,8 +4,9 @@ from .rpn import RPN
 from .fast_rcnn import FastRCNN
 from .faster_rcnn import FasterRCNN
 from .mask_rcnn import MaskRCNN
+from .cascade_rcnn import CascadeRCNN
 
 __all__ = [
     'BaseDetector', 'TwoStageDetector', 'RPN', 'FastRCNN', 'FasterRCNN',
-    'MaskRCNN'
+    'MaskRCNN', 'CascadeRCNN'
 ]
diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py
new file mode 100644
index 0000000..4da055a
--- /dev/null
+++ b/mmdet/models/detectors/cascade_rcnn.py
@@ -0,0 +1,291 @@
+import torch
+import torch.nn as nn
+
+from .base import BaseDetector
+from .test_mixins import RPNTestMixin
+from .. import builder
+from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply,
+                        merge_aug_masks)
+
+
+class CascadeRCNN(BaseDetector, RPNTestMixin):
+
+    def __init__(self,
+                 num_stages,
+                 backbone,
+                 neck=None,
+                 rpn_head=None,
+                 bbox_roi_extractor=None,
+                 bbox_head=None,
+                 mask_roi_extractor=None,
+                 mask_head=None,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None):
+        assert bbox_roi_extractor is not None
+        assert bbox_head is not None
+        super(CascadeRCNN, self).__init__()
+
+        self.num_stages = num_stages
+        self.backbone = builder.build_backbone(backbone)
+
+        if neck is not None:
+            self.neck = builder.build_neck(neck)
+        else:
+            raise NotImplementedError
+
+        if rpn_head is not None:
+            self.rpn_head = builder.build_rpn_head(rpn_head)
+
+        if bbox_head is not None:
+            self.bbox_roi_extractor = nn.ModuleList()
+            self.bbox_head = nn.ModuleList()
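+            # A single config is replicated across all stages; pass a list to
+            # configure each stage separately.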
+            if not isinstance(bbox_roi_extractor, list):
+                bbox_roi_extractor = [
+                    bbox_roi_extractor for _ in range(num_stages)
+                ]
+            if not isinstance(bbox_head, list):
+                bbox_head = [bbox_head for _ in range(num_stages)]
+            assert len(bbox_roi_extractor) == len(bbox_head) == self.num_stages
+            for roi_extractor, head in zip(bbox_roi_extractor, bbox_head):
+                self.bbox_roi_extractor.append(
+                    builder.build_roi_extractor(roi_extractor))
+                self.bbox_head.append(builder.build_bbox_head(head))
+
+        if mask_head is not None:
+            self.mask_roi_extractor = nn.ModuleList()
+            self.mask_head = nn.ModuleList()
+            if not isinstance(mask_roi_extractor, list):
+                mask_roi_extractor = [
+                    mask_roi_extractor for _ in range(num_stages)
+                ]
+            if not isinstance(mask_head, list):
+                mask_head = [mask_head for _ in range(num_stages)]
+            assert len(mask_roi_extractor) == len(mask_head) == self.num_stages
+            for roi_extractor, head in zip(mask_roi_extractor, mask_head):
+                self.mask_roi_extractor.append(
+                    builder.build_roi_extractor(roi_extractor))
+                self.mask_head.append(builder.build_mask_head(head))
+
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+
+        self.init_weights(pretrained=pretrained)
+
+    @property
+    def with_rpn(self):
+        return hasattr(self, 'rpn_head') and self.rpn_head is not None
+
+    def init_weights(self, pretrained=None):
+        super(CascadeRCNN, self).init_weights(pretrained)
+        self.backbone.init_weights(pretrained=pretrained)
+        if self.with_neck:
+            if isinstance(self.neck, nn.Sequential):
+                for m in self.neck:
+                    m.init_weights()
+            else:
+                self.neck.init_weights()
+        if self.with_rpn:
+            self.rpn_head.init_weights()
+        for i in range(self.num_stages):
+            if self.with_bbox:
+                self.bbox_roi_extractor[i].init_weights()
+                self.bbox_head[i].init_weights()
+            if self.with_mask:
+                self.mask_roi_extractor[i].init_weights()
+                self.mask_head[i].init_weights()
+
+    def extract_feat(self, img):
+        x = self.backbone(img)
+        if self.with_neck:
+            x = self.neck(x)
+        return x
+
+    def forward_train(self,
+                      img,
+                      img_meta,
+                      gt_bboxes,
+                      gt_bboxes_ignore,
+                      gt_labels,
+                      gt_masks=None,
+                      proposals=None):
+        x = self.extract_feat(img)
+
+        losses = dict()
+
+        if self.with_rpn:
+            rpn_outs = self.rpn_head(x)
+            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
+                                          self.train_cfg.rpn)
+            rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
+            losses.update(rpn_losses)
+
+            proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
+            proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+        else:
+            proposal_list = proposals
+
+        for i in range(self.num_stages):
+            rcnn_train_cfg = self.train_cfg.rcnn[i]
+            lw = self.train_cfg.loss_weight[i]
+
+            # assign gts and sample proposals
+            assign_results, sampling_results = multi_apply(
+                assign_and_sample,
+                proposal_list,
+                gt_bboxes,
+                gt_bboxes_ignore,
+                gt_labels,
+                cfg=rcnn_train_cfg)
+
+            # bbox head forward and loss
+            bbox_roi_extractor = self.bbox_roi_extractor[i]
+            bbox_head = self.bbox_head[i]
+
+            rois = bbox2roi([res.bboxes for res in sampling_results])
+            bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
+                                            rois)
+            cls_score, bbox_pred = bbox_head(bbox_feats)
+
+            bbox_targets = bbox_head.get_target(sampling_results, gt_bboxes,
+                                                gt_labels, rcnn_train_cfg)
+            loss_bbox = bbox_head.loss(cls_score, bbox_pred, *bbox_targets)
+            for name, value in loss_bbox.items():
+                losses['s{}.{}'.format(i, name)] = (value * lw if
+                                                    'loss' in name else value)
+
+            # mask head forward and loss
+            if self.with_mask:
+                mask_roi_extractor = self.mask_roi_extractor[i]
+                mask_head = self.mask_head[i]
+                pos_rois = bbox2roi(
+                    [res.pos_bboxes for res in sampling_results])
+                mask_feats = mask_roi_extractor(
+                    x[:mask_roi_extractor.num_inputs], pos_rois)
+                mask_pred = mask_head(mask_feats)
+                mask_targets = mask_head.get_target(sampling_results, gt_masks,
+                                                    rcnn_train_cfg)
+                pos_labels = torch.cat(
+                    [res.pos_gt_labels for res in sampling_results])
+                loss_mask = mask_head.loss(mask_pred, mask_targets, pos_labels)
+                for name, value in loss_mask.items():
+                    losses['s{}.{}'.format(i, name)] = (value * lw
+                                                        if 'loss' in name else
+                                                        value)
+
+            # refine bboxes
+            if i < self.num_stages - 1:
+                pos_is_gts = [res.pos_is_gt for res in sampling_results]
+                roi_labels = bbox_targets[0]  # bbox_targets is a tuple
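+                # Decode this stage's predictions into refined boxes that act
+                # as the next stage's proposals; gradients are not propagated
+                # through the refinement.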
+                with torch.no_grad():
+                    proposal_list = bbox_head.refine_bboxes(
+                        rois, roi_labels, bbox_pred, pos_is_gts, img_meta)
+
+        return losses
+
+    def simple_test(self, img, img_meta, proposals=None, rescale=False):
+        x = self.extract_feat(img)
+        proposal_list = self.simple_test_rpn(
+            x, img_meta, self.test_cfg.rpn) if proposals is None else proposals
+
+        img_shape = img_meta[0]['img_shape']
+        ori_shape = img_meta[0]['ori_shape']
+        scale_factor = img_meta[0]['scale_factor']
+
+        # "ms" in variable names means multi-stage
+        ms_bbox_result = []
+        ms_segm_result = []
+        ms_scores = []
+        rcnn_test_cfg = self.test_cfg.rcnn
+
+        rois = bbox2roi(proposal_list)
+        for i in range(self.num_stages):
+            bbox_roi_extractor = self.bbox_roi_extractor[i]
+            bbox_head = self.bbox_head[i]
+
+            bbox_feats = bbox_roi_extractor(
+                x[:len(bbox_roi_extractor.featmap_strides)], rois)
+            cls_score, bbox_pred = bbox_head(bbox_feats)
+            ms_scores.append(cls_score)
+
+            if self.test_cfg.keep_all_stages:
+                det_bboxes, det_labels = bbox_head.get_det_bboxes(
+                    rois,
+                    cls_score,
+                    bbox_pred,
+                    img_shape,
+                    scale_factor,
+                    rescale=rescale,
+                    nms_cfg=rcnn_test_cfg)
+                bbox_result = bbox2result(det_bboxes, det_labels,
+                                          bbox_head.num_classes)
+                ms_bbox_result.append(bbox_result)
+
+                if self.with_mask:
+                    mask_roi_extractor = self.mask_roi_extractor[i]
+                    mask_head = self.mask_head[i]
+                    if det_bboxes.shape[0] == 0:
+                        segm_result = [
+                            [] for _ in range(mask_head.num_classes - 1)
+                        ]
+                    else:
+                        _bboxes = (det_bboxes[:, :4] * scale_factor
+                                   if rescale else det_bboxes)
+                        mask_rois = bbox2roi([_bboxes])
+                        mask_feats = mask_roi_extractor(
+                            x[:len(mask_roi_extractor.featmap_strides)],
+                            mask_rois)
+                        mask_pred = mask_head(mask_feats)
+                        segm_result = mask_head.get_seg_masks(
+                            mask_pred, _bboxes, det_labels, rcnn_test_cfg,
+                            ori_shape, scale_factor, rescale)
+                    ms_segm_result.append(segm_result)
+
+            if i < self.num_stages - 1:
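+                # Feed this stage's refined boxes to the next stage, using
+                # each RoI's predicted class to pick the regression deltas.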
+                bbox_label = cls_score.argmax(dim=1)
+                rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred,
+                                                  img_meta[0])
+
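+        # Ensemble: average the classification scores of all stages and
+        # decode the final boxes with the last stage's regressor.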
+        cls_score = sum(ms_scores) / float(len(ms_scores))
+        det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes(
+            rois,
+            cls_score,
+            bbox_pred,
+            img_shape,
+            scale_factor,
+            rescale=rescale,
+            nms_cfg=rcnn_test_cfg)
+        bbox_result = bbox2result(det_bboxes, det_labels,
+                                  self.bbox_head[-1].num_classes)
+        ms_bbox_result.append(bbox_result)
+
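+        # For masks, run every stage's mask head on the final boxes and merge
+        # the predictions before decoding.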
+        if self.with_mask:
+            if det_bboxes.shape[0] == 0:
+                segm_result = [
+                    [] for _ in range(self.mask_head[-1].num_classes - 1)
+                ]
+            else:
+                _bboxes = (det_bboxes[:, :4] * scale_factor
+                           if rescale else det_bboxes)
+                mask_rois = bbox2roi([_bboxes])
+                aug_masks = []
+                for i in range(self.num_stages):
+                    mask_roi_extractor = self.mask_roi_extractor[i]
+                    mask_feats = mask_roi_extractor(
+                        x[:len(mask_roi_extractor.featmap_strides)],
+                        mask_rois)
+                    mask_pred = self.mask_head[i](mask_feats)
+                    aug_masks.append(mask_pred.sigmoid().cpu().numpy())
+                merged_masks = merge_aug_masks(
+                    aug_masks, [img_meta[0]] * self.num_stages,
+                    self.test_cfg.rcnn)
+                segm_result = self.mask_head[-1].get_seg_masks(
+                    merged_masks, _bboxes, det_labels, rcnn_test_cfg,
+                    ori_shape, scale_factor, rescale)
+            ms_segm_result.append(segm_result)
+
+        if self.with_mask:
+            return ms_bbox_result, ms_segm_result
+        else:
+            return ms_bbox_result
+
+    def aug_test(self, imgs, img_metas, proposals=None, rescale=False):
+        raise NotImplementedError
+
+    def show_result(self, data, result, img_norm_cfg, **kwargs):
+        # TODO: show segmentation masks
+        if self.with_mask:
+            ms_bbox_result, ms_segm_result = result
+        else:
+            ms_bbox_result = result
+        super(CascadeRCNN, self).show_result(data, ms_bbox_result[-1],
+                                             img_norm_cfg, **kwargs)
-- 
GitLab