diff --git a/configs/double_heads/dh_faster_rcnn_r50_fpn_1x.py b/configs/double_heads/dh_faster_rcnn_r50_fpn_1x.py new file mode 100644 index 0000000000000000000000000000000000000000..cef46570279d5835ddad67b2569ecc5ec8ccfad4 --- /dev/null +++ b/configs/double_heads/dh_faster_rcnn_r50_fpn_1x.py @@ -0,0 +1,170 @@ +# model settings +model = dict( + type='DoubleHeadRCNN', + pretrained='modelzoo://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_scales=[8], + anchor_ratios=[0.5, 1.0, 2.0], + anchor_strides=[4, 8, 16, 32, 64], + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0], + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + reg_roi_scale_factor=1.3, + bbox_head=dict( + type='DoubleConvFCBBoxHead', + num_convs=4, + num_fcs=2, + in_channels=256, + conv_out_channels=1024, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=81, + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2], + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=2.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=2.0))) +# model training and testing settings +train_cfg = dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_across_levels=False, + nms_pre=2000, + nms_post=2000, + max_num=2000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)) +test_cfg = dict( + rpn=dict( + nms_across_levels=False, + nms_pre=1000, + nms_post=1000, + max_num=1000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05) +) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +data = dict( + imgs_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0.5, + with_mask=False, + with_crowd=True, + with_label=True), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=False, + with_crowd=True, + with_label=True), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=False, + with_label=False, + test_mode=True)) +# optimizer +optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[8, 11]) +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +# runtime settings +total_epochs = 12 +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/dh_faster_rcnn_r50_fpn_1x' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/mmdet/models/bbox_heads/__init__.py b/mmdet/models/bbox_heads/__init__.py index dac10bdc144636ccb3953db5dfccfc1a583ceec2..a668bdb0182361b5213adb877a21246e23ab4ede 100644 --- a/mmdet/models/bbox_heads/__init__.py +++ b/mmdet/models/bbox_heads/__init__.py @@ -1,4 +1,7 @@ from .bbox_head import BBoxHead from .convfc_bbox_head import ConvFCBBoxHead, SharedFCBBoxHead +from .double_bbox_head import DoubleConvFCBBoxHead -__all__ = ['BBoxHead', 'ConvFCBBoxHead', 'SharedFCBBoxHead'] +__all__ = [ + 'BBoxHead', 'ConvFCBBoxHead', 'SharedFCBBoxHead', 'DoubleConvFCBBoxHead' +] diff --git a/mmdet/models/bbox_heads/double_bbox_head.py b/mmdet/models/bbox_heads/double_bbox_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7211d9b49373f2c646e6c30c8c02f60b6c8f7a35 --- /dev/null +++ b/mmdet/models/bbox_heads/double_bbox_head.py @@ -0,0 +1,170 @@ +import torch.nn as nn +from mmcv.cnn.weight_init import normal_init, xavier_init + +from .bbox_head import BBoxHead +from ..backbones.resnet import Bottleneck +from ..registry import HEADS +from ..utils import ConvModule + + +class BasicResBlock(nn.Module): + """Basic residual block. + + This block is a little different from the block in the ResNet backbone. + The kernel size of conv1 is 1 in this block while 3 in ResNet BasicBlock. + + Args: + in_channels (int): Channels of the input feature map. + out_channels (int): Channels of the output feature map. + conv_cfg (dict): The config dict for convolution layers. + norm_cfg (dict): The config dict for normalization layers. + """ + + def __init__(self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN')): + super(BasicResBlock, self).__init__() + + # main path + self.conv1 = ConvModule( + in_channels, + in_channels, + kernel_size=3, + padding=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + self.conv2 = ConvModule( + in_channels, + out_channels, + kernel_size=1, + bias=False, + activation=None, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + + # identity path + self.conv_identity = ConvModule( + in_channels, + out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + activation=None) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + identity = x + + x = self.conv1(x) + x = self.conv2(x) + + identity = self.conv_identity(identity) + out = x + identity + + out = self.relu(out) + return out + + +@HEADS.register_module +class DoubleConvFCBBoxHead(BBoxHead): + """Bbox head used in Double-Head R-CNN + + /-> cls + /-> shared convs -> + \-> reg + roi features + /-> cls + \-> shared fc -> + \-> reg + """ # noqa: W605 + + def __init__(self, + num_convs=0, + num_fcs=0, + conv_out_channels=1024, + fc_out_channels=1024, + conv_cfg=None, + norm_cfg=dict(type='BN'), + **kwargs): + kwargs.setdefault('with_avg_pool', True) + super(DoubleConvFCBBoxHead, self).__init__(**kwargs) + assert self.with_avg_pool + assert num_convs > 0 + assert num_fcs > 0 + self.num_convs = num_convs + self.num_fcs = num_fcs + self.conv_out_channels = conv_out_channels + self.fc_out_channels = fc_out_channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + # increase the channel of input features + self.res_block = BasicResBlock(self.in_channels, + self.conv_out_channels) + + # add conv heads + self.conv_branch = self._add_conv_branch() + # add fc heads + self.fc_branch = self._add_fc_branch() + + out_dim_reg = 4 if self.reg_class_agnostic else 4 * self.num_classes + self.fc_reg = nn.Linear(self.conv_out_channels, out_dim_reg) + + self.fc_cls = nn.Linear(self.fc_out_channels, self.num_classes) + self.relu = nn.ReLU(inplace=True) + + def _add_conv_branch(self): + """Add the fc branch which consists of a sequential of conv layers""" + branch_convs = nn.ModuleList() + for i in range(self.num_convs): + branch_convs.append( + Bottleneck( + inplanes=self.conv_out_channels, + planes=self.conv_out_channels // 4, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + return branch_convs + + def _add_fc_branch(self): + """Add the fc branch which consists of a sequential of fc layers""" + branch_fcs = nn.ModuleList() + for i in range(self.num_fcs): + fc_in_channels = ( + self.in_channels * self.roi_feat_size * + self.roi_feat_size if i == 0 else self.fc_out_channels) + branch_fcs.append(nn.Linear(fc_in_channels, self.fc_out_channels)) + return branch_fcs + + def init_weights(self): + normal_init(self.fc_cls, std=0.01) + normal_init(self.fc_reg, std=0.001) + + for m in self.fc_branch.modules(): + if isinstance(m, nn.Linear): + xavier_init(m, distribution='uniform') + + def forward(self, x_cls, x_reg): + # conv head + x_conv = self.res_block(x_reg) + + for conv in self.conv_branch: + x_conv = conv(x_conv) + + if self.with_avg_pool: + x_conv = self.avg_pool(x_conv) + + x_conv = x_conv.view(x_conv.size(0), -1) + bbox_pred = self.fc_reg(x_conv) + + # fc head + x_fc = x_cls.view(x_cls.size(0), -1) + for fc in self.fc_branch: + x_fc = self.relu(fc(x_fc)) + + cls_score = self.fc_cls(x_fc) + + return cls_score, bbox_pred diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py index 9917f5d50310ee10fbf5d3be28fd3e5dde109efa..4d97b49e2e93c8dabb4a263bf54c2c5bd5fe66e5 100644 --- a/mmdet/models/detectors/__init__.py +++ b/mmdet/models/detectors/__init__.py @@ -6,6 +6,7 @@ from .fast_rcnn import FastRCNN from .faster_rcnn import FasterRCNN from .mask_rcnn import MaskRCNN from .cascade_rcnn import CascadeRCNN +from .double_head_rcnn import DoubleHeadRCNN from .htc import HybridTaskCascade from .retinanet import RetinaNet from .fcos import FCOS @@ -15,5 +16,5 @@ from .mask_scoring_rcnn import MaskScoringRCNN __all__ = [ 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN', 'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade', - 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN' + 'DoubleHeadRCNN', 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN' ] diff --git a/mmdet/models/detectors/double_head_rcnn.py b/mmdet/models/detectors/double_head_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..50001872713a3ea57750cda28d1805be8ad434dc --- /dev/null +++ b/mmdet/models/detectors/double_head_rcnn.py @@ -0,0 +1,154 @@ +import torch + +from .two_stage import TwoStageDetector +from ..registry import DETECTORS +from mmdet.core import bbox2roi, build_assigner, build_sampler + + +@DETECTORS.register_module +class DoubleHeadRCNN(TwoStageDetector): + + def __init__(self, reg_roi_scale_factor, **kwargs): + super().__init__(**kwargs) + self.reg_roi_scale_factor = reg_roi_scale_factor + + def forward_train(self, + img, + img_meta, + gt_bboxes, + gt_labels, + gt_bboxes_ignore=None, + gt_masks=None, + proposals=None): + x = self.extract_feat(img) + + losses = dict() + + # RPN forward and loss + if self.with_rpn: + rpn_outs = self.rpn_head(x) + rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, + self.train_cfg.rpn) + rpn_losses = self.rpn_head.loss( + *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) + losses.update(rpn_losses) + + proposal_cfg = self.train_cfg.get('rpn_proposal', + self.test_cfg.rpn) + proposal_inputs = rpn_outs + (img_meta, proposal_cfg) + proposal_list = self.rpn_head.get_bboxes(*proposal_inputs) + else: + proposal_list = proposals + + # assign gts and sample proposals + if self.with_bbox or self.with_mask: + bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner) + bbox_sampler = build_sampler( + self.train_cfg.rcnn.sampler, context=self) + num_imgs = img.size(0) + if gt_bboxes_ignore is None: + gt_bboxes_ignore = [None for _ in range(num_imgs)] + sampling_results = [] + for i in range(num_imgs): + assign_result = bbox_assigner.assign(proposal_list[i], + gt_bboxes[i], + gt_bboxes_ignore[i], + gt_labels[i]) + sampling_result = bbox_sampler.sample( + assign_result, + proposal_list[i], + gt_bboxes[i], + gt_labels[i], + feats=[lvl_feat[i][None] for lvl_feat in x]) + sampling_results.append(sampling_result) + + # bbox head forward and loss + if self.with_bbox: + rois = bbox2roi([res.bboxes for res in sampling_results]) + # TODO: a more flexible way to decide which feature maps to use + bbox_cls_feats = self.bbox_roi_extractor( + x[:self.bbox_roi_extractor.num_inputs], rois) + bbox_reg_feats = self.bbox_roi_extractor( + x[:self.bbox_roi_extractor.num_inputs], + rois, + roi_scale_factor=self.reg_roi_scale_factor) + if self.with_shared_head: + bbox_cls_feats = self.shared_head(bbox_cls_feats) + bbox_reg_feats = self.shared_head(bbox_reg_feats) + cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, + bbox_reg_feats) + + bbox_targets = self.bbox_head.get_target(sampling_results, + gt_bboxes, gt_labels, + self.train_cfg.rcnn) + loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, + *bbox_targets) + losses.update(loss_bbox) + + # mask head forward and loss + if self.with_mask: + if not self.share_roi_extractor: + pos_rois = bbox2roi( + [res.pos_bboxes for res in sampling_results]) + mask_feats = self.mask_roi_extractor( + x[:self.mask_roi_extractor.num_inputs], pos_rois) + if self.with_shared_head: + mask_feats = self.shared_head(mask_feats) + else: + pos_inds = [] + device = bbox_cls_feats.device + for res in sampling_results: + pos_inds.append( + torch.ones( + res.pos_bboxes.shape[0], + device=device, + dtype=torch.uint8)) + pos_inds.append( + torch.zeros( + res.neg_bboxes.shape[0], + device=device, + dtype=torch.uint8)) + pos_inds = torch.cat(pos_inds) + mask_feats = bbox_cls_feats[pos_inds] + mask_pred = self.mask_head(mask_feats) + + mask_targets = self.mask_head.get_target(sampling_results, + gt_masks, + self.train_cfg.rcnn) + pos_labels = torch.cat( + [res.pos_gt_labels for res in sampling_results]) + loss_mask = self.mask_head.loss(mask_pred, mask_targets, + pos_labels) + losses.update(loss_mask) + + return losses + + def simple_test_bboxes(self, + x, + img_meta, + proposals, + rcnn_test_cfg, + rescale=False): + """Test only det bboxes without augmentation.""" + rois = bbox2roi(proposals) + bbox_cls_feats = self.bbox_roi_extractor( + x[:self.bbox_roi_extractor.num_inputs], rois) + bbox_reg_feats = self.bbox_roi_extractor( + x[:self.bbox_roi_extractor.num_inputs], + rois, + roi_scale_factor=self.reg_roi_scale_factor) + if self.with_shared_head: + bbox_cls_feats = self.shared_head(bbox_cls_feats) + bbox_reg_feats = self.shared_head(bbox_reg_feats) + cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats) + img_shape = img_meta[0]['img_shape'] + scale_factor = img_meta[0]['scale_factor'] + det_bboxes, det_labels = self.bbox_head.get_det_bboxes( + rois, + cls_score, + bbox_pred, + img_shape, + scale_factor, + rescale=rescale, + cfg=rcnn_test_cfg) + return det_bboxes, det_labels diff --git a/mmdet/models/roi_extractors/single_level.py b/mmdet/models/roi_extractors/single_level.py index 7eaab120e6e1a6a2b1243752b23348f3b0cca713..503818ef368179cb0258645941725c3e52c9d4ac 100644 --- a/mmdet/models/roi_extractors/single_level.py +++ b/mmdet/models/roi_extractors/single_level.py @@ -72,8 +72,22 @@ class SingleRoIExtractor(nn.Module): target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long() return target_lvls + def roi_rescale(self, rois, scale_factor): + cx = (rois[:, 1] + rois[:, 3]) * 0.5 + cy = (rois[:, 2] + rois[:, 4]) * 0.5 + w = rois[:, 3] - rois[:, 1] + 1 + h = rois[:, 4] - rois[:, 2] + 1 + new_w = w * scale_factor + new_h = h * scale_factor + x1 = cx - new_w * 0.5 + 0.5 + x2 = cx + new_w * 0.5 - 0.5 + y1 = cy - new_h * 0.5 + 0.5 + y2 = cy + new_h * 0.5 - 0.5 + new_rois = torch.stack((rois[:, 0], x1, y1, x2, y2), dim=-1) + return new_rois + @force_fp32(apply_to=('feats', ), out_fp16=True) - def forward(self, feats, rois): + def forward(self, feats, rois, roi_scale_factor=None): if len(feats) == 1: return self.roi_layers[0](feats[0], rois) @@ -82,6 +96,8 @@ class SingleRoIExtractor(nn.Module): target_lvls = self.map_roi_levels(rois, num_levels) roi_feats = feats[0].new_zeros(rois.size()[0], self.out_channels, out_size, out_size) + if roi_scale_factor is not None: + rois = self.roi_rescale(rois, roi_scale_factor) for i in range(num_levels): inds = target_lvls == i if inds.any():