diff --git a/configs/ms_rcnn/README.md b/configs/ms_rcnn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3b08ad15e7ae833570729021b9de0bda6569b332 --- /dev/null +++ b/configs/ms_rcnn/README.md @@ -0,0 +1,23 @@ +# Mask Scoring R-CNN + +## Introduction + +``` +@inproceedings{huang2019msrcnn, + title={Mask Scoring R-CNN}, + author={Zhaojin Huang and Lichao Huang and Yongchao Gong and Chang Huang and Xinggang Wang}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + year={2019}, +} +``` + +## Results and Models + +| Backbone | style | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | mask AP | Download | +|:-------------:|:----------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:-------:|:--------:| +| R-50-FPN | caffe | 1x | 4.3 | 0.537 | 10.1 | 37.4 | 35.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/ms-rcnn/ms_rcnn_r50_caffe_fpn_1x_20190624-619934b5.pth) | +| R-50-FPN | caffe | 2x | - | - | - | 38.2 | 35.9 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ms-rcnn/ms_rcnn_r50_caffe_fpn_2x_20190525-a07be31e.pth) | +| R-101-FPN | caffe | 1x | 6.2 | 0.682 | 9.1 | 39.8 | 37.2 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/ms-rcnn/ms_rcnn_r101_caffe_fpn_1x_20190624-677a5548.pth) | +| R-101-FPN | caffe | 2x | - | - | - | 40.7 | 37.8 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ms-rcnn/ms_rcnn_r101_caffe_fpn_2x_20190525-4aee1528.pth) | +| R-X101-64x4d | pytorch | 1x | 10.5 | 1.214 | 6.4 | 42.2 | 39.2 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ms-rcnn/ms_rcnn_x101_64x4d_fpn_1x_20190525-026b16ae.pth) | +| R-X101-64x4d | pytorch | 2x | - | - | - | 42.2 | 38.9 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/ms-rcnn/ms_rcnn_x101_64x4d_fpn_2x_20190525-c044c25a.pth) | diff --git a/configs/ms_rcnn/ms_rcnn_r101_caffe_fpn_1x.py b/configs/ms_rcnn/ms_rcnn_r101_caffe_fpn_1x.py new file mode 100644 index 0000000000000000000000000000000000000000..682d308a937675bc2a03c59982ab853bd7c39b15 --- /dev/null +++ b/configs/ms_rcnn/ms_rcnn_r101_caffe_fpn_1x.py @@ -0,0 +1,192 @@ +# model settings +model = dict( + type='MaskScoringRCNN', + pretrained='open-mmlab://resnet101_caffe', + backbone=dict( + type='ResNet', + depth=101, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + style='caffe'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_scales=[8], + anchor_ratios=[0.5, 1.0, 2.0], + anchor_strides=[4, 8, 16, 32, 64], + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0], + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='SharedFCBBoxHead', + num_fcs=2, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=81, + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2], + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), + mask_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + mask_head=dict( + type='FCNMaskHead', + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=81, + loss_mask=dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)), + mask_iou_head=dict( + type='MaskIoUHead', + num_convs=4, + num_fcs=2, + roi_feat_size=14, + in_channels=256, + conv_out_channels=256, + fc_out_channels=1024, + num_classes=81)) +# model training and testing settings +train_cfg = dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_across_levels=False, + nms_pre=2000, + nms_post=2000, + max_num=2000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + mask_thr_binary=0.5, + debug=False)) +test_cfg = dict( + rpn=dict( + nms_across_levels=False, + nms_pre=1000, + nms_post=1000, + max_num=1000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_thr=0.5), + max_per_img=100, + mask_thr_binary=0.5)) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False) +data = dict( + imgs_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0.5, + with_mask=True, + with_crowd=True, + with_label=True), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=True, + with_crowd=True, + with_label=True), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=False, + with_label=False, + test_mode=True)) +# optimizer +optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[8, 11]) +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +# runtime settings +total_epochs = 12 +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/ms_rcnn_r101_caffe_fpn_1x' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/configs/ms_rcnn/ms_rcnn_r50_caffe_fpn_1x.py b/configs/ms_rcnn/ms_rcnn_r50_caffe_fpn_1x.py new file mode 100644 index 0000000000000000000000000000000000000000..16ebde50bfcc25b79447dc894352a8d48638d40b --- /dev/null +++ b/configs/ms_rcnn/ms_rcnn_r50_caffe_fpn_1x.py @@ -0,0 +1,192 @@ +# model settings +model = dict( + type='MaskScoringRCNN', + pretrained='open-mmlab://resnet50_caffe', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + style='caffe'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_scales=[8], + anchor_ratios=[0.5, 1.0, 2.0], + anchor_strides=[4, 8, 16, 32, 64], + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0], + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='SharedFCBBoxHead', + num_fcs=2, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=81, + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2], + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), + mask_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + mask_head=dict( + type='FCNMaskHead', + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=81, + loss_mask=dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)), + mask_iou_head=dict( + type='MaskIoUHead', + num_convs=4, + num_fcs=2, + roi_feat_size=14, + in_channels=256, + conv_out_channels=256, + fc_out_channels=1024, + num_classes=81)) +# model training and testing settings +train_cfg = dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_across_levels=False, + nms_pre=2000, + nms_post=2000, + max_num=2000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + mask_thr_binary=0.5, + debug=False)) +test_cfg = dict( + rpn=dict( + nms_across_levels=False, + nms_pre=1000, + nms_post=1000, + max_num=1000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_thr=0.5), + max_per_img=100, + mask_thr_binary=0.5)) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False) +data = dict( + imgs_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0.5, + with_mask=True, + with_crowd=True, + with_label=True), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=True, + with_crowd=True, + with_label=True), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=False, + with_label=False, + test_mode=True)) +# optimizer +optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[8, 11]) +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +# runtime settings +total_epochs = 12 +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/ms_rcnn_r50_caffe_fpn_1x' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/configs/ms_rcnn/ms_rcnn_x101_64x4d_fpn_1x.py b/configs/ms_rcnn/ms_rcnn_x101_64x4d_fpn_1x.py new file mode 100644 index 0000000000000000000000000000000000000000..a6b720332cf33263295dcfeeae0d85b793e5166d --- /dev/null +++ b/configs/ms_rcnn/ms_rcnn_x101_64x4d_fpn_1x.py @@ -0,0 +1,193 @@ +# model settings +model = dict( + type='MaskScoringRCNN', + pretrained='open-mmlab://resnext101_64x4d', + backbone=dict( + type='ResNeXt', + depth=101, + groups=64, + base_width=4, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_scales=[8], + anchor_ratios=[0.5, 1.0, 2.0], + anchor_strides=[4, 8, 16, 32, 64], + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0], + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='SharedFCBBoxHead', + num_fcs=2, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=81, + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2], + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), + mask_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + mask_head=dict( + type='FCNMaskHead', + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=81, + loss_mask=dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)), + mask_iou_head=dict( + type='MaskIoUHead', + num_convs=4, + num_fcs=2, + roi_feat_size=14, + in_channels=256, + conv_out_channels=256, + fc_out_channels=1024, + num_classes=81)) +# model training and testing settings +train_cfg = dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_across_levels=False, + nms_pre=2000, + nms_post=2000, + max_num=2000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + mask_thr_binary=0.5, + debug=False)) +test_cfg = dict( + rpn=dict( + nms_across_levels=False, + nms_pre=1000, + nms_post=1000, + max_num=1000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_thr=0.5), + max_per_img=100, + mask_thr_binary=0.5)) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +data = dict( + imgs_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0.5, + with_mask=True, + with_crowd=True, + with_label=True), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=True, + with_crowd=True, + with_label=True), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=False, + with_label=False, + test_mode=True)) +# optimizer +optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[8, 11]) +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +# runtime settings +total_epochs = 12 +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/ms_rcnn_x101_64x4d_fpn_1x' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py index 824833d8823d924c75ebe68c3ffca4c80270a981..2424a8bd06cb0b43b522ed98c90d3655205b58ed 100644 --- a/mmdet/models/detectors/__init__.py +++ b/mmdet/models/detectors/__init__.py @@ -9,9 +9,10 @@ from .cascade_rcnn import CascadeRCNN from .htc import HybridTaskCascade from .retinanet import RetinaNet from .fcos import FCOS +from .mask_scoring_rcnn import MaskScoringRCNN __all__ = [ 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN', 'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade', - 'RetinaNet', 'FCOS' + 'RetinaNet', 'FCOS', 'MaskScoringRCNN' ] diff --git a/mmdet/models/detectors/mask_scoring_rcnn.py b/mmdet/models/detectors/mask_scoring_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..9153bcd147b5c4f9658fb7d09403b78b8ba4e181 --- /dev/null +++ b/mmdet/models/detectors/mask_scoring_rcnn.py @@ -0,0 +1,197 @@ +import torch + +from mmdet.core import bbox2roi, build_assigner, build_sampler +from .two_stage import TwoStageDetector +from .. import builder +from ..registry import DETECTORS + + +@DETECTORS.register_module +class MaskScoringRCNN(TwoStageDetector): + """Mask Scoring RCNN. + + https://arxiv.org/abs/1903.00241 + """ + + def __init__(self, + backbone, + rpn_head, + bbox_roi_extractor, + bbox_head, + mask_roi_extractor, + mask_head, + train_cfg, + test_cfg, + neck=None, + shared_head=None, + mask_iou_head=None, + pretrained=None): + super(MaskScoringRCNN, self).__init__( + backbone=backbone, + neck=neck, + shared_head=shared_head, + rpn_head=rpn_head, + bbox_roi_extractor=bbox_roi_extractor, + bbox_head=bbox_head, + mask_roi_extractor=mask_roi_extractor, + mask_head=mask_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + pretrained=pretrained) + + self.mask_iou_head = builder.build_head(mask_iou_head) + self.mask_iou_head.init_weights() + + # TODO: refactor forward_train in two stage to reduce code redundancy + def forward_train(self, + img, + img_meta, + gt_bboxes, + gt_labels, + gt_bboxes_ignore=None, + gt_masks=None, + proposals=None): + x = self.extract_feat(img) + + losses = dict() + + # RPN forward and loss + if self.with_rpn: + rpn_outs = self.rpn_head(x) + rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, + self.train_cfg.rpn) + rpn_losses = self.rpn_head.loss( + *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) + losses.update(rpn_losses) + + proposal_cfg = self.train_cfg.get('rpn_proposal', + self.test_cfg.rpn) + proposal_inputs = rpn_outs + (img_meta, proposal_cfg) + proposal_list = self.rpn_head.get_bboxes(*proposal_inputs) + else: + proposal_list = proposals + + # assign gts and sample proposals + if self.with_bbox or self.with_mask: + bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner) + bbox_sampler = build_sampler( + self.train_cfg.rcnn.sampler, context=self) + num_imgs = img.size(0) + if gt_bboxes_ignore is None: + gt_bboxes_ignore = [None for _ in range(num_imgs)] + sampling_results = [] + for i in range(num_imgs): + assign_result = bbox_assigner.assign(proposal_list[i], + gt_bboxes[i], + gt_bboxes_ignore[i], + gt_labels[i]) + sampling_result = bbox_sampler.sample( + assign_result, + proposal_list[i], + gt_bboxes[i], + gt_labels[i], + feats=[lvl_feat[i][None] for lvl_feat in x]) + sampling_results.append(sampling_result) + + # bbox head forward and loss + if self.with_bbox: + rois = bbox2roi([res.bboxes for res in sampling_results]) + # TODO: a more flexible way to decide which feature maps to use + bbox_feats = self.bbox_roi_extractor( + x[:self.bbox_roi_extractor.num_inputs], rois) + if self.with_shared_head: + bbox_feats = self.shared_head(bbox_feats) + cls_score, bbox_pred = self.bbox_head(bbox_feats) + + bbox_targets = self.bbox_head.get_target(sampling_results, + gt_bboxes, gt_labels, + self.train_cfg.rcnn) + loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, + *bbox_targets) + losses.update(loss_bbox) + + # mask head forward and loss + if self.with_mask: + if not self.share_roi_extractor: + pos_rois = bbox2roi( + [res.pos_bboxes for res in sampling_results]) + mask_feats = self.mask_roi_extractor( + x[:self.mask_roi_extractor.num_inputs], pos_rois) + if self.with_shared_head: + mask_feats = self.shared_head(mask_feats) + else: + pos_inds = [] + device = bbox_feats.device + for res in sampling_results: + pos_inds.append( + torch.ones( + res.pos_bboxes.shape[0], + device=device, + dtype=torch.uint8)) + pos_inds.append( + torch.zeros( + res.neg_bboxes.shape[0], + device=device, + dtype=torch.uint8)) + pos_inds = torch.cat(pos_inds) + mask_feats = bbox_feats[pos_inds] + mask_pred = self.mask_head(mask_feats) + + mask_targets = self.mask_head.get_target(sampling_results, + gt_masks, + self.train_cfg.rcnn) + pos_labels = torch.cat( + [res.pos_gt_labels for res in sampling_results]) + loss_mask = self.mask_head.loss(mask_pred, mask_targets, + pos_labels) + losses.update(loss_mask) + + # mask iou head forward and loss + pos_mask_pred = mask_pred[range(mask_pred.size(0)), pos_labels] + mask_iou_pred = self.mask_iou_head(mask_feats, pos_mask_pred) + pos_mask_iou_pred = mask_iou_pred[range(mask_iou_pred.size(0) + ), pos_labels] + mask_iou_targets = self.mask_iou_head.get_target( + sampling_results, gt_masks, pos_mask_pred, mask_targets, + self.train_cfg.rcnn) + loss_mask_iou = self.mask_iou_head.loss(pos_mask_iou_pred, + mask_iou_targets) + losses.update(loss_mask_iou) + return losses + + def simple_test_mask(self, + x, + img_meta, + det_bboxes, + det_labels, + rescale=False): + # image shape of the first image in the batch (only one) + ori_shape = img_meta[0]['ori_shape'] + scale_factor = img_meta[0]['scale_factor'] + + if det_bboxes.shape[0] == 0: + segm_result = [[] for _ in range(self.mask_head.num_classes - 1)] + mask_scores = [[] for _ in range(self.mask_head.num_classes - 1)] + else: + # if det_bboxes is rescaled to the original image size, we need to + # rescale it back to the testing scale to obtain RoIs. + _bboxes = ( + det_bboxes[:, :4] * scale_factor if rescale else det_bboxes) + mask_rois = bbox2roi([_bboxes]) + mask_feats = self.mask_roi_extractor( + x[:len(self.mask_roi_extractor.featmap_strides)], mask_rois) + if self.with_shared_head: + mask_feats = self.shared_head(mask_feats) + mask_pred = self.mask_head(mask_feats) + segm_result = self.mask_head.get_seg_masks(mask_pred, _bboxes, + det_labels, + self.test_cfg.rcnn, + ori_shape, scale_factor, + rescale) + # get mask scores with mask iou head + mask_iou_pred = self.mask_iou_head( + mask_feats, + mask_pred[range(det_labels.size(0)), det_labels + 1]) + mask_scores = self.mask_iou_head.get_mask_scores( + mask_iou_pred, det_bboxes, det_labels) + return segm_result, mask_scores diff --git a/mmdet/models/losses/__init__.py b/mmdet/models/losses/__init__.py index 45920dd3c209f1f756f0294bb1465a3faca61dcd..531e5f1b85883230e919c8fd1a156f000251a88e 100644 --- a/mmdet/models/losses/__init__.py +++ b/mmdet/models/losses/__init__.py @@ -5,6 +5,7 @@ from .focal_loss import sigmoid_focal_loss, FocalLoss from .smooth_l1_loss import smooth_l1_loss, SmoothL1Loss from .ghm_loss import GHMC, GHMR from .balanced_l1_loss import balanced_l1_loss, BalancedL1Loss +from .mse_loss import mse_loss, MSELoss from .iou_loss import iou_loss, bounded_iou_loss, IoULoss, BoundedIoULoss from .utils import reduce_loss, weight_reduce_loss, weighted_loss @@ -12,7 +13,7 @@ __all__ = [ 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', 'mask_cross_entropy', 'CrossEntropyLoss', 'sigmoid_focal_loss', 'FocalLoss', 'smooth_l1_loss', 'SmoothL1Loss', 'balanced_l1_loss', - 'BalancedL1Loss', 'iou_loss', 'bounded_iou_loss', 'IoULoss', - 'BoundedIoULoss', 'GHMC', 'GHMR', 'reduce_loss', 'weight_reduce_loss', - 'weighted_loss' + 'BalancedL1Loss', 'mse_loss', 'MSELoss', 'iou_loss', 'bounded_iou_loss', + 'IoULoss', 'BoundedIoULoss', 'GHMC', 'GHMR', 'reduce_loss', + 'weight_reduce_loss', 'weighted_loss' ] diff --git a/mmdet/models/losses/mse_loss.py b/mmdet/models/losses/mse_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a50f4592017e4413989874ea19342edad94fead7 --- /dev/null +++ b/mmdet/models/losses/mse_loss.py @@ -0,0 +1,25 @@ +import torch.nn as nn +import torch.nn.functional as F + +from .utils import weighted_loss +from ..registry import LOSSES + +mse_loss = weighted_loss(F.mse_loss) + + +@LOSSES.register_module +class MSELoss(nn.Module): + + def __init__(self, reduction='mean', loss_weight=1.0): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, pred, target, weight=None, avg_factor=None): + loss = self.loss_weight * mse_loss( + pred, + target, + weight, + reduction=self.reduction, + avg_factor=avg_factor) + return loss diff --git a/mmdet/models/mask_heads/__init__.py b/mmdet/models/mask_heads/__init__.py index 899e4645a190d9912722305b86f9b166ed946d0f..ce1f3e7f3d5c8a6a05dc0a11ccfb410c67f02670 100644 --- a/mmdet/models/mask_heads/__init__.py +++ b/mmdet/models/mask_heads/__init__.py @@ -1,5 +1,6 @@ from .fcn_mask_head import FCNMaskHead -from .htc_mask_head import HTCMaskHead from .fused_semantic_head import FusedSemanticHead +from .htc_mask_head import HTCMaskHead +from .maskiou_head import MaskIoUHead -__all__ = ['FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead'] +__all__ = ['FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead', 'MaskIoUHead'] diff --git a/mmdet/models/mask_heads/maskiou_head.py b/mmdet/models/mask_heads/maskiou_head.py new file mode 100644 index 0000000000000000000000000000000000000000..bfa1764345d1ef399ee66d9da11760086db0076f --- /dev/null +++ b/mmdet/models/mask_heads/maskiou_head.py @@ -0,0 +1,181 @@ +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import kaiming_init, normal_init + +from ..builder import build_loss +from ..registry import HEADS + + +@HEADS.register_module +class MaskIoUHead(nn.Module): + """Mask IoU Head. + + This head predicts the IoU of predicted masks and corresponding gt masks. + """ + + def __init__(self, + num_convs=4, + num_fcs=2, + roi_feat_size=14, + in_channels=256, + conv_out_channels=256, + fc_out_channels=1024, + num_classes=81, + loss_iou=dict(type='MSELoss', loss_weight=0.5)): + super(MaskIoUHead, self).__init__() + self.in_channels = in_channels + self.conv_out_channels = conv_out_channels + self.fc_out_channels = fc_out_channels + self.num_classes = num_classes + + self.convs = nn.ModuleList() + for i in range(num_convs): + if i == 0: + # concatenation of mask feature and mask prediction + in_channels = self.in_channels + 1 + else: + in_channels = self.conv_out_channels + stride = 2 if i == num_convs - 1 else 1 + self.convs.append( + nn.Conv2d( + in_channels, + self.conv_out_channels, + 3, + stride=stride, + padding=1)) + + self.fcs = nn.ModuleList() + for i in range(num_fcs): + in_channels = self.conv_out_channels * ( + roi_feat_size // 2)**2 if i == 0 else self.fc_out_channels + self.fcs.append(nn.Linear(in_channels, self.fc_out_channels)) + + self.fc_mask_iou = nn.Linear(self.fc_out_channels, self.num_classes) + self.relu = nn.ReLU() + self.max_pool = nn.MaxPool2d(2, 2) + self.loss_iou = build_loss(loss_iou) + + def init_weights(self): + for conv in self.convs: + kaiming_init(conv) + for fc in self.fcs: + kaiming_init( + fc, + a=1, + mode='fan_in', + nonlinearity='leaky_relu', + distribution='uniform') + normal_init(self.fc_mask_iou, std=0.01) + + def forward(self, mask_feat, mask_pred): + mask_pred = mask_pred.sigmoid() + mask_pred_pooled = self.max_pool(mask_pred.unsqueeze(1)) + + x = torch.cat((mask_feat, mask_pred_pooled), 1) + + for conv in self.convs: + x = self.relu(conv(x)) + x = x.view(x.size(0), -1) + for fc in self.fcs: + x = self.relu(fc(x)) + mask_iou = self.fc_mask_iou(x) + return mask_iou + + def loss(self, mask_iou_pred, mask_iou_targets): + pos_inds = mask_iou_targets > 0 + if pos_inds.sum() > 0: + loss_mask_iou = self.loss_iou(mask_iou_pred[pos_inds], + mask_iou_targets[pos_inds]) + else: + loss_mask_iou = mask_iou_pred * 0 + return dict(loss_mask_iou=loss_mask_iou) + + def get_target(self, sampling_results, gt_masks, mask_pred, mask_targets, + rcnn_train_cfg): + """Compute target of mask IoU. + + Mask IoU target is the IoU of the predicted mask (inside a bbox) and + the gt mask of corresponding gt mask (the whole instance). + The intersection area is computed inside the bbox, and the gt mask area + is computed with two steps, firstly we compute the gt area inside the + bbox, then divide it by the area ratio of gt area inside the bbox and + the gt area of the whole instance. + + Args: + sampling_results (list[:obj:`SamplingResult`]): sampling results. + gt_masks (list[ndarray]): Gt masks (the whole instance) of each + image, binary maps with the same shape of the input image. + mask_pred (Tensor): Predicted masks of each positive proposal, + shape (num_pos, h, w). + mask_targets (Tensor): Gt mask of each positive proposal, + binary map of the shape (num_pos, h, w). + rcnn_train_cfg (dict): Training config for R-CNN part. + + Returns: + Tensor: mask iou target (length == num positive). + """ + pos_proposals = [res.pos_bboxes for res in sampling_results] + pos_assigned_gt_inds = [ + res.pos_assigned_gt_inds for res in sampling_results + ] + + # compute the area ratio of gt areas inside the proposals and + # the whole instance + area_ratios = map(self._get_area_ratio, pos_proposals, + pos_assigned_gt_inds, gt_masks) + area_ratios = torch.cat(list(area_ratios)) + assert mask_targets.size(0) == area_ratios.size(0) + + mask_pred = (mask_pred > rcnn_train_cfg.mask_thr_binary).float() + mask_pred_areas = mask_pred.sum((-1, -2)) + + # mask_pred and mask_targets are binary maps + overlap_areas = (mask_pred * mask_targets).sum((-1, -2)) + + # compute the mask area of the whole instance + gt_full_areas = mask_targets.sum((-1, -2)) / (area_ratios + 1e-7) + + mask_iou_targets = overlap_areas / ( + mask_pred_areas + gt_full_areas - overlap_areas) + return mask_iou_targets + + def _get_area_ratio(self, pos_proposals, pos_assigned_gt_inds, gt_masks): + """Compute area ratio of the gt mask inside the proposal and the gt + mask of the corresponding instance""" + num_pos = pos_proposals.size(0) + if num_pos > 0: + area_ratios = [] + proposals_np = pos_proposals.cpu().numpy() + pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy() + # compute mask areas of gt instances (batch processing for speedup) + gt_instance_mask_area = gt_masks.sum((-1, -2)) + for i in range(num_pos): + gt_mask = gt_masks[pos_assigned_gt_inds[i]] + + # crop the gt mask inside the proposal + x1, y1, x2, y2 = proposals_np[i, :].astype(np.int32) + gt_mask_in_proposal = gt_mask[y1:y2 + 1, x1:x2 + 1] + + ratio = gt_mask_in_proposal.sum() / ( + gt_instance_mask_area[pos_assigned_gt_inds[i]] + 1e-7) + area_ratios.append(ratio) + area_ratios = torch.from_numpy(np.stack(area_ratios)).float().to( + pos_proposals.device) + else: + area_ratios = pos_proposals.new_zeros((0, )) + return area_ratios + + def get_mask_scores(self, mask_iou_pred, det_bboxes, det_labels): + """Get the mask scores. + + mask_score = bbox_score * mask_iou + """ + inds = range(det_labels.size(0)) + mask_scores = mask_iou_pred[inds, det_labels + + 1] * det_bboxes[inds, -1] + mask_scores = mask_scores.cpu().numpy() + det_labels = det_labels.cpu().numpy() + return [ + mask_scores[det_labels == i] for i in range(self.num_classes - 1) + ]