From e746639503c7c81c10f4c789616a21b089e058a9 Mon Sep 17 00:00:00 2001 From: Cao Yuhang <yhcao6@gmail.com> Date: Fri, 28 Jun 2019 16:49:00 +0800 Subject: [PATCH] Support FP16 training (#520) * add fp16 support * fpn donot need bn normalize * refactor wrapped bn * fix bug of retinanet * add fp16 ssd300 voc, cascade r50, cascade mask r50 * fix bug in cascade rcnn testing * add support to fix bn training * add fix bn cfg * delete fixbn cfg, mv fixbn fp16 to a new branch * fix cascade mask fp16 bug in test * fix bug in cascade mask rcnn fp16 test * add more fp16 cfgs * add fp16 fast-r50 and faster-dconv-r50 * add fp16 test, minor fix * clean code * fix config work_dir name * add patch func, refactor code * fix format * clean code * move convert rois to single_level_extractor * fix bug for cascade mask, the seg mask is ndarray * refactor code, add two decorator force_fp32 and auto_fp16 * add fp16_enable attribute * add more comment, fix format and test assertion * fix pep8 format error * format commont and api * rename distribute to distributed, fix dict copy * rename function name * move function, add comment * remove unused parameter * mv decorators into decorators.py, hook related functions to hook * add auto_fp16 to forward of semantic head * add auto_fp16 to all heads and fpn * add docstrings and minor bug fix * simple refactoring * bug fix for patching forward method * roi extractor in fp32 mode * fix flake8 error * fix ci error * add fp16 support to ga head * remove parallel test assert * minor fix * add comment in build_optimizer * fix typo in comment * fix typo enable --> enabling * udpate README --- README.md | 2 +- configs/fp16/faster_rcnn_r50_fpn_fp16_1x.py | 170 ++++++++++++++++ configs/fp16/mask_rcnn_r50_fpn_fp16_1x.py | 185 ++++++++++++++++++ configs/fp16/retinanet_r50_fpn_fp16_1x.py | 127 ++++++++++++ mmdet/apis/train.py | 30 ++- mmdet/core/__init__.py | 3 +- mmdet/core/fp16/__init__.py | 4 + mmdet/core/fp16/decorators.py | 160 +++++++++++++++ mmdet/core/fp16/hooks.py | 126 ++++++++++++ mmdet/core/fp16/utils.py | 23 +++ mmdet/core/utils/dist_utils.py | 7 +- mmdet/models/anchor_heads/anchor_head.py | 5 +- mmdet/models/anchor_heads/fcos_head.py | 8 +- .../models/anchor_heads/guided_anchor_head.py | 65 +++--- mmdet/models/anchor_heads/ssd_head.py | 1 + mmdet/models/backbones/ssd_vgg.py | 8 +- mmdet/models/bbox_heads/bbox_head.py | 9 +- mmdet/models/detectors/base.py | 4 +- mmdet/models/mask_heads/fcn_mask_head.py | 8 +- .../models/mask_heads/fused_semantic_head.py | 4 + mmdet/models/necks/fpn.py | 3 + mmdet/models/roi_extractors/single_level.py | 7 +- mmdet/models/shared_heads/res_layer.py | 3 + tools/test.py | 5 +- 24 files changed, 915 insertions(+), 52 deletions(-) create mode 100644 configs/fp16/faster_rcnn_r50_fpn_fp16_1x.py create mode 100644 configs/fp16/mask_rcnn_r50_fpn_fp16_1x.py create mode 100644 configs/fp16/retinanet_r50_fpn_fp16_1x.py create mode 100644 mmdet/core/fp16/__init__.py create mode 100644 mmdet/core/fp16/decorators.py create mode 100644 mmdet/core/fp16/hooks.py create mode 100644 mmdet/core/fp16/utils.py diff --git a/README.md b/README.md index 7638ab0..b3e425d 100644 --- a/README.md +++ b/README.md @@ -106,7 +106,7 @@ Other features - [x] Soft-NMS - [x] Generalized Attention - [x] GCNet -- [ ] Mixed Precision (FP16) Training (coming soon) +- [x] Mixed Precision (FP16) Training ## Installation diff --git a/configs/fp16/faster_rcnn_r50_fpn_fp16_1x.py b/configs/fp16/faster_rcnn_r50_fpn_fp16_1x.py new file mode 100644 index 0000000..2e844a2 --- /dev/null +++ 
b/configs/fp16/faster_rcnn_r50_fpn_fp16_1x.py @@ -0,0 +1,170 @@ +# fp16 settings +fp16 = dict(loss_scale=512.) + +# model settings +model = dict( + type='FasterRCNN', + pretrained='modelzoo://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_scales=[8], + anchor_ratios=[0.5, 1.0, 2.0], + anchor_strides=[4, 8, 16, 32, 64], + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0], + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='SharedFCBBoxHead', + num_fcs=2, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=81, + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2], + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))) +# model training and testing settings +train_cfg = dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_across_levels=False, + nms_pre=2000, + nms_post=2000, + max_num=2000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)) +test_cfg = dict( + rpn=dict( + nms_across_levels=False, + nms_pre=1000, + nms_post=1000, + max_num=1000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100) + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05) +) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +data = dict( + imgs_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0.5, + with_mask=False, + with_crowd=True, + with_label=True), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=False, + with_crowd=True, + with_label=True), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=False, + with_label=False, + 
test_mode=True)) +# optimizer +optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[8, 11]) +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +# runtime settings +total_epochs = 12 +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/faster_rcnn_r50_fpn_fp16_1x' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/configs/fp16/mask_rcnn_r50_fpn_fp16_1x.py b/configs/fp16/mask_rcnn_r50_fpn_fp16_1x.py new file mode 100644 index 0000000..092978b --- /dev/null +++ b/configs/fp16/mask_rcnn_r50_fpn_fp16_1x.py @@ -0,0 +1,185 @@ +# fp16 settings +fp16 = dict(loss_scale=512.) + +# model settings +model = dict( + type='MaskRCNN', + pretrained='modelzoo://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_scales=[8], + anchor_ratios=[0.5, 1.0, 2.0], + anchor_strides=[4, 8, 16, 32, 64], + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0], + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='SharedFCBBoxHead', + num_fcs=2, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=81, + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2], + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), + mask_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + mask_head=dict( + type='FCNMaskHead', + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=81, + loss_mask=dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))) +# model training and testing settings +train_cfg = dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_across_levels=False, + nms_pre=2000, + nms_post=2000, + max_num=2000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False)) +test_cfg = dict( + rpn=dict( + nms_across_levels=False, + nms_pre=1000, + nms_post=1000, + max_num=1000, + nms_thr=0.7, + min_bbox_size=0), + 
rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_thr=0.5), + max_per_img=100, + mask_thr_binary=0.5)) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +data = dict( + imgs_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0.5, + with_mask=True, + with_crowd=True, + with_label=True), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=True, + with_crowd=True, + with_label=True), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=False, + with_label=False, + test_mode=True)) +# optimizer +optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[8, 11]) +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +evaluation = dict(interval=1) +# runtime settings +total_epochs = 12 +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/mask_rcnn_r50_fpn_fp16_1x' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/configs/fp16/retinanet_r50_fpn_fp16_1x.py b/configs/fp16/retinanet_r50_fpn_fp16_1x.py new file mode 100644 index 0000000..8b5ce0c --- /dev/null +++ b/configs/fp16/retinanet_r50_fpn_fp16_1x.py @@ -0,0 +1,127 @@ +# fp16 settings +fp16 = dict(loss_scale=512.) 
+ +# model settings +model = dict( + type='RetinaNet', + pretrained='modelzoo://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs=True, + num_outs=5), + bbox_head=dict( + type='RetinaHead', + num_classes=81, + in_channels=256, + stacked_convs=4, + feat_channels=256, + octave_base_scale=4, + scales_per_octave=3, + anchor_ratios=[0.5, 1.0, 2.0], + anchor_strides=[8, 16, 32, 64, 128], + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0], + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=1.0))) +# training and testing settings +train_cfg = dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False) +test_cfg = dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_thr=0.5), + max_per_img=100) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +data = dict( + imgs_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0.5, + with_mask=False, + with_crowd=False, + with_label=True), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=False, + with_crowd=False, + with_label=True), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + img_scale=(1333, 800), + img_norm_cfg=img_norm_cfg, + size_divisor=32, + flip_ratio=0, + with_mask=False, + with_crowd=False, + with_label=False, + test_mode=True)) +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[8, 11]) +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +# runtime settings +total_epochs = 12 +device_ids = range(8) +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/retinanet_r50_fpn_fp16_1x' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index 5e40dcf..60d6c7d 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -9,7 +9,8 @@ from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmdet import datasets from mmdet.core import (DistOptimizerHook, DistEvalmAPHook, - CocoDistEvalRecallHook, CocoDistEvalmAPHook) + CocoDistEvalRecallHook, CocoDistEvalmAPHook, + Fp16OptimizerHook) from mmdet.datasets import build_dataloader from mmdet.models import RPN from .env import get_root_logger @@ -109,10 +110,14 @@ def 
build_optimizer(model, optimizer_cfg): # set param-wise lr and weight decay params = [] for name, param in model.named_parameters(): + param_group = {'params': [param]} if not param.requires_grad: + # FP16 training needs to copy gradient/weight between master + # weight copy and model weight, it is convenient to keep all + # parameters here to align with model.parameters() + params.append(param_group) continue - param_group = {'params': [param]} # for norm layers, overwrite the weight decay of weight and bias # TODO: obtain the norm layer prefixes dynamically if re.search(r'(bn|gn)(\d+)?.(weight|bias)', name): @@ -142,12 +147,21 @@ def _dist_train(model, dataset, cfg, validate=False): ] # put model on gpus model = MMDistributedDataParallel(model.cuda()) + # build runner optimizer = build_optimizer(model, cfg.optimizer) runner = Runner(model, batch_processor, optimizer, cfg.work_dir, cfg.log_level) + + # fp16 setting + fp16_cfg = cfg.get('fp16', None) + if fp16_cfg is not None: + optimizer_config = Fp16OptimizerHook(**cfg.optimizer_config, + **fp16_cfg) + else: + optimizer_config = DistOptimizerHook(**cfg.optimizer_config) + # register hooks - optimizer_config = DistOptimizerHook(**cfg.optimizer_config) runner.register_training_hooks(cfg.lr_config, optimizer_config, cfg.checkpoint_config, cfg.log_config) runner.register_hook(DistSamplerSeedHook()) @@ -187,11 +201,19 @@ def _non_dist_train(model, dataset, cfg, validate=False): ] # put model on gpus model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda() + # build runner optimizer = build_optimizer(model, cfg.optimizer) runner = Runner(model, batch_processor, optimizer, cfg.work_dir, cfg.log_level) - runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config, + # fp16 setting + fp16_cfg = cfg.get('fp16', None) + if fp16_cfg is not None: + optimizer_config = Fp16OptimizerHook( + **cfg.optimizer_config, **fp16_cfg, distributed=False) + else: + optimizer_config = cfg.optimizer_config + runner.register_training_hooks(cfg.lr_config, optimizer_config, cfg.checkpoint_config, cfg.log_config) if cfg.resume_from: diff --git a/mmdet/core/__init__.py b/mmdet/core/__init__.py index d118b14..f8eb6cb 100644 --- a/mmdet/core/__init__.py +++ b/mmdet/core/__init__.py @@ -1,6 +1,7 @@ from .anchor import * # noqa: F401, F403 from .bbox import * # noqa: F401, F403 -from .mask import * # noqa: F401, F403 from .evaluation import * # noqa: F401, F403 +from .fp16 import * # noqa: F401, F403 +from .mask import * # noqa: F401, F403 from .post_processing import * # noqa: F401, F403 from .utils import * # noqa: F401, F403 diff --git a/mmdet/core/fp16/__init__.py b/mmdet/core/fp16/__init__.py new file mode 100644 index 0000000..cc655b7 --- /dev/null +++ b/mmdet/core/fp16/__init__.py @@ -0,0 +1,4 @@ +from .decorators import auto_fp16, force_fp32 +from .hooks import Fp16OptimizerHook, wrap_fp16_model + +__all__ = ['auto_fp16', 'force_fp32', 'Fp16OptimizerHook', 'wrap_fp16_model'] diff --git a/mmdet/core/fp16/decorators.py b/mmdet/core/fp16/decorators.py new file mode 100644 index 0000000..10ffbf8 --- /dev/null +++ b/mmdet/core/fp16/decorators.py @@ -0,0 +1,160 @@ +import functools +from inspect import getfullargspec + +import torch + +from .utils import cast_tensor_type + + +def auto_fp16(apply_to=None, out_fp32=False): + """Decorator to enable fp16 training automatically. + + This decorator is useful when you write custom modules and want to support + mixed precision training. 
If input arguments are fp32 tensors, they will + be converted to fp16 automatically. Arguments other than fp32 tensors are + ignored. + + Args: + apply_to (Iterable, optional): The argument names to be converted. + `None` indicates all arguments. + out_fp32 (bool): Whether to convert the output back to fp32. + + :Example: + + class MyModule1(nn.Module): + + # Convert x and y to fp16 + @auto_fp16() + def forward(self, x, y): + pass + + class MyModule2(nn.Module): + + # convert pred to fp16 + @auto_fp16(apply_to=('pred', )) + def do_something(self, pred, others): + pass + """ + + def auto_fp16_wrapper(old_func): + + @functools.wraps(old_func) + def new_func(*args, **kwargs): + # check if the module has set the attribute `fp16_enabled`, if not, + # just fall back to the original method. + if not isinstance(args[0], torch.nn.Module): + raise TypeError('@auto_fp16 can only be used to decorate the ' + 'method of nn.Module') + if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled): + return old_func(*args, **kwargs) + # get the arg spec of the decorated method + args_info = getfullargspec(old_func) + # get the argument names to be cast + args_to_cast = args_info.args if apply_to is None else apply_to + # convert the args that need to be processed + new_args = [] + # NOTE: default args are not taken into consideration + if args: + arg_names = args_info.args[:len(args)] + for i, arg_name in enumerate(arg_names): + if arg_name in args_to_cast: + new_args.append( + cast_tensor_type(args[i], torch.float, torch.half)) + else: + new_args.append(args[i]) + # convert the kwargs that need to be processed + new_kwargs = {} + if kwargs: + for arg_name, arg_value in kwargs.items(): + if arg_name in args_to_cast: + new_kwargs[arg_name] = cast_tensor_type( + arg_value, torch.float, torch.half) + else: + new_kwargs[arg_name] = arg_value + # apply converted arguments to the decorated method + output = old_func(*new_args, **new_kwargs) + # cast the results back to fp32 if necessary + if out_fp32: + output = cast_tensor_type(output, torch.half, torch.float) + return output + + return new_func + + return auto_fp16_wrapper + + +def force_fp32(apply_to=None, out_fp16=False): + """Decorator to forcibly convert input arguments to fp32. + + This decorator is useful when you write custom modules and want to support + mixed precision training. If there are some inputs that must be processed + in fp32 mode, then this decorator can handle it. If input arguments are + fp16 tensors, they will be converted to fp32 automatically. Arguments other + than fp16 tensors are ignored. + + Args: + apply_to (Iterable, optional): The argument names to be converted. + `None` indicates all arguments. + out_fp16 (bool): Whether to convert the output back to fp16. + + :Example: + + class MyModule1(nn.Module): + + # Convert x and y to fp32 + @force_fp32() + def loss(self, x, y): + pass + + class MyModule2(nn.Module): + + # convert pred to fp32 + @force_fp32(apply_to=('pred', )) + def post_process(self, pred, others): + pass + """ + + def force_fp32_wrapper(old_func): + + @functools.wraps(old_func) + def new_func(*args, **kwargs): + # check if the module has set the attribute `fp16_enabled`, if not, + # just fall back to the original method.
+ if not isinstance(args[0], torch.nn.Module): + raise TypeError('@force_fp32 can only be used to decorate the ' + 'method of nn.Module') + if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled): + return old_func(*args, **kwargs) + # get the arg spec of the decorated method + args_info = getfullargspec(old_func) + # get the argument names to be cast + args_to_cast = args_info.args if apply_to is None else apply_to + # convert the args that need to be processed + new_args = [] + if args: + arg_names = args_info.args[:len(args)] + for i, arg_name in enumerate(arg_names): + if arg_name in args_to_cast: + new_args.append( + cast_tensor_type(args[i], torch.half, torch.float)) + else: + new_args.append(args[i]) + # convert the kwargs that need to be processed + new_kwargs = dict() + if kwargs: + for arg_name, arg_value in kwargs.items(): + if arg_name in args_to_cast: + new_kwargs[arg_name] = cast_tensor_type( + arg_value, torch.half, torch.float) + else: + new_kwargs[arg_name] = arg_value + # apply converted arguments to the decorated method + output = old_func(*new_args, **new_kwargs) + # cast the results back to fp16 if necessary + if out_fp16: + output = cast_tensor_type(output, torch.float, torch.half) + return output + + return new_func + + return force_fp32_wrapper diff --git a/mmdet/core/fp16/hooks.py b/mmdet/core/fp16/hooks.py new file mode 100644 index 0000000..b1ab45e --- /dev/null +++ b/mmdet/core/fp16/hooks.py @@ -0,0 +1,126 @@ +import copy +import torch +import torch.nn as nn +from mmcv.runner import OptimizerHook + +from .utils import cast_tensor_type +from ..utils.dist_utils import allreduce_grads + + +class Fp16OptimizerHook(OptimizerHook): + """FP16 optimizer hook. + + The steps of the fp16 optimizer are as follows. + 1. Scale the loss value. + 2. BP in the fp16 model. + 3. Copy gradients from fp16 model to fp32 weights. + 4. Update fp32 weights. + 5. Copy updated parameters from fp32 weights to fp16 model. + + Refer to https://arxiv.org/abs/1710.03740 for more details. + + Args: + loss_scale (float): Scale factor multiplied with loss.
+ """ + + def __init__(self, + grad_clip=None, + coalesce=True, + bucket_size_mb=-1, + loss_scale=512., + distributed=True): + self.grad_clip = grad_clip + self.coalesce = coalesce + self.bucket_size_mb = bucket_size_mb + self.loss_scale = loss_scale + self.distributed = distributed + + def before_run(self, runner): + # keep a copy of fp32 weights + runner.optimizer.param_groups = copy.deepcopy( + runner.optimizer.param_groups) + # convert model to fp16 + wrap_fp16_model(runner.model) + + def copy_grads_to_fp32(self, fp16_net, fp32_weights): + """Copy gradients from fp16 model to fp32 weight copy.""" + for fp32_param, fp16_param in zip(fp32_weights, fp16_net.parameters()): + if fp16_param.grad is not None: + if fp32_param.grad is None: + fp32_param.grad = fp32_param.data.new(fp32_param.size()) + fp32_param.grad.copy_(fp16_param.grad) + + def copy_params_to_fp16(self, fp16_net, fp32_weights): + """Copy updated params from fp32 weight copy to fp16 model.""" + for fp16_param, fp32_param in zip(fp16_net.parameters(), fp32_weights): + fp16_param.data.copy_(fp32_param.data) + + def after_train_iter(self, runner): + # clear grads of last iteration + runner.model.zero_grad() + runner.optimizer.zero_grad() + # scale the loss value + scaled_loss = runner.outputs['loss'] * self.loss_scale + scaled_loss.backward() + # copy fp16 grads in the model to fp32 params in the optimizer + fp32_weights = [] + for param_group in runner.optimizer.param_groups: + fp32_weights += param_group['params'] + self.copy_grads_to_fp32(runner.model, fp32_weights) + # allreduce grads + if self.distributed: + allreduce_grads(fp32_weights, self.coalesce, self.bucket_size_mb) + # scale the gradients back + for param in fp32_weights: + if param.grad is not None: + param.grad.div_(self.loss_scale) + if self.grad_clip is not None: + self.clip_grads(fp32_weights) + # update fp32 params + runner.optimizer.step() + # copy fp32 params to the fp16 model + self.copy_params_to_fp16(runner.model, fp32_weights) + + +def wrap_fp16_model(model): + # convert model to fp16 + model.half() + # patch the normalization layers to make it work in fp32 mode + patch_norm_fp32(model) + # set `fp16_enabled` flag + for m in model.modules(): + if hasattr(m, 'fp16_enabled'): + m.fp16_enabled = True + + +def patch_norm_fp32(module): + if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)): + module.float() + module.forward = patch_forward_method(module.forward, torch.half, + torch.float) + for child in module.children(): + patch_norm_fp32(child) + return module + + +def patch_forward_method(func, src_type, dst_type, convert_output=True): + """Patch the forward method of a module. + + Args: + func (callable): The original forward method. + src_type (torch.dtype): Type of input arguments to be converted from. + dst_type (torch.dtype): Type of input arguments to be converted to. + convert_output (bool): Whether to convert the output back to src_type. + + Returns: + callable: The patched forward method. 
+ """ + + def new_forward(*args, **kwargs): + output = func(*cast_tensor_type(args, src_type, dst_type), + **cast_tensor_type(kwargs, src_type, dst_type)) + if convert_output: + output = cast_tensor_type(output, dst_type, src_type) + return output + + return new_forward diff --git a/mmdet/core/fp16/utils.py b/mmdet/core/fp16/utils.py new file mode 100644 index 0000000..ce691c7 --- /dev/null +++ b/mmdet/core/fp16/utils.py @@ -0,0 +1,23 @@ +from collections import abc + +import numpy as np +import torch + + +def cast_tensor_type(inputs, src_type, dst_type): + if isinstance(inputs, torch.Tensor): + return inputs.to(dst_type) + elif isinstance(inputs, str): + return inputs + elif isinstance(inputs, np.ndarray): + return inputs + elif isinstance(inputs, abc.Mapping): + return type(inputs)({ + k: cast_tensor_type(v, src_type, dst_type) + for k, v in inputs.items() + }) + elif isinstance(inputs, abc.Iterable): + return type(inputs)( + cast_tensor_type(item, src_type, dst_type) for item in inputs) + else: + return inputs diff --git a/mmdet/core/utils/dist_utils.py b/mmdet/core/utils/dist_utils.py index ec84bb4..51d7e3c 100644 --- a/mmdet/core/utils/dist_utils.py +++ b/mmdet/core/utils/dist_utils.py @@ -28,9 +28,9 @@ def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): tensor.copy_(synced) -def allreduce_grads(model, coalesce=True, bucket_size_mb=-1): +def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): grads = [ - param.grad.data for param in model.parameters() + param.grad.data for param in params if param.requires_grad and param.grad is not None ] world_size = dist.get_world_size() @@ -51,7 +51,8 @@ class DistOptimizerHook(OptimizerHook): def after_train_iter(self, runner): runner.optimizer.zero_grad() runner.outputs['loss'].backward() - allreduce_grads(runner.model, self.coalesce, self.bucket_size_mb) + allreduce_grads(runner.model.parameters(), self.coalesce, + self.bucket_size_mb) if self.grad_clip is not None: self.clip_grads(runner.model.parameters()) runner.optimizer.step() diff --git a/mmdet/models/anchor_heads/anchor_head.py b/mmdet/models/anchor_heads/anchor_head.py index 1883e91..2b8b144 100644 --- a/mmdet/models/anchor_heads/anchor_head.py +++ b/mmdet/models/anchor_heads/anchor_head.py @@ -6,7 +6,7 @@ import torch.nn as nn from mmcv.cnn import normal_init from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox, - multi_apply, multiclass_nms) + multi_apply, multiclass_nms, force_fp32) from ..builder import build_loss from ..registry import HEADS @@ -64,6 +64,7 @@ class AnchorHead(nn.Module): self.cls_out_channels = num_classes self.loss_cls = build_loss(loss_cls) self.loss_bbox = build_loss(loss_bbox) + self.fp16_enabled = False self.anchor_generators = [] for anchor_base in self.anchor_base_sizes: @@ -149,6 +150,7 @@ class AnchorHead(nn.Module): avg_factor=num_total_samples) return loss_cls, loss_bbox + @force_fp32(apply_to=('cls_scores', 'bbox_preds')) def loss(self, cls_scores, bbox_preds, @@ -193,6 +195,7 @@ class AnchorHead(nn.Module): cfg=cfg) return dict(loss_cls=losses_cls, loss_bbox=losses_bbox) + @force_fp32(apply_to=('cls_scores', 'bbox_preds')) def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg, rescale=False): assert len(cls_scores) == len(bbox_preds) diff --git a/mmdet/models/anchor_heads/fcos_head.py b/mmdet/models/anchor_heads/fcos_head.py index a5ad9bc..957906d 100644 --- a/mmdet/models/anchor_heads/fcos_head.py +++ b/mmdet/models/anchor_heads/fcos_head.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn from mmcv.cnn 
import normal_init -from mmdet.core import multi_apply, multiclass_nms, distance2bbox +from mmdet.core import multi_apply, multiclass_nms, distance2bbox, force_fp32 from ..builder import build_loss from ..registry import HEADS from ..utils import bias_init_with_prob, Scale, ConvModule @@ -48,6 +48,7 @@ class FCOSHead(nn.Module): self.loss_centerness = build_loss(loss_centerness) self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg + self.fp16_enabled = False self._init_layers() @@ -108,9 +109,11 @@ class FCOSHead(nn.Module): for reg_layer in self.reg_convs: reg_feat = reg_layer(reg_feat) # scale the bbox_pred of different level - bbox_pred = scale(self.fcos_reg(reg_feat)).exp() + # float to avoid overflow when enabling FP16 + bbox_pred = scale(self.fcos_reg(reg_feat)).float().exp() return cls_score, bbox_pred, centerness + @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses')) def loss(self, cls_scores, bbox_preds, @@ -183,6 +186,7 @@ class FCOSHead(nn.Module): loss_bbox=loss_bbox, loss_centerness=loss_centerness) + @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses')) def get_bboxes(self, cls_scores, bbox_preds, diff --git a/mmdet/models/anchor_heads/guided_anchor_head.py b/mmdet/models/anchor_heads/guided_anchor_head.py index 8b5dc54..c3cc705 100644 --- a/mmdet/models/anchor_heads/guided_anchor_head.py +++ b/mmdet/models/anchor_heads/guided_anchor_head.py @@ -7,7 +7,7 @@ from mmcv.cnn import normal_init from mmdet.core import (AnchorGenerator, anchor_target, anchor_inside_flags, ga_loc_target, ga_shape_target, delta2bbox, - multi_apply, multiclass_nms) + multi_apply, multiclass_nms, force_fp32) from mmdet.ops import DeformConv, MaskedConv2d from ..builder import build_loss from .anchor_head import AnchorHead @@ -93,37 +93,32 @@ class GuidedAnchorHead(AnchorHead): loss_bbox (dict): Config of bbox regression loss. 
""" - def __init__(self, - num_classes, - in_channels, - feat_channels=256, - octave_base_scale=8, - scales_per_octave=3, - octave_ratios=[0.5, 1.0, 2.0], - anchor_strides=[4, 8, 16, 32, 64], - anchor_base_sizes=None, - anchoring_means=(.0, .0, .0, .0), - anchoring_stds=(1.0, 1.0, 1.0, 1.0), - target_means=(.0, .0, .0, .0), - target_stds=(1.0, 1.0, 1.0, 1.0), - deformable_groups=4, - loc_filter_thr=0.01, - loss_loc=dict( - type='FocalLoss', - use_sigmoid=True, - gamma=2.0, - alpha=0.25, - loss_weight=1.0), - loss_shape=dict( - type='BoundedIoULoss', - beta=0.2, - loss_weight=1.0), - loss_cls=dict( - type='CrossEntropyLoss', - use_sigmoid=True, - loss_weight=1.0), - loss_bbox=dict( - type='SmoothL1Loss', beta=1.0, loss_weight=1.0)): + def __init__( + self, + num_classes, + in_channels, + feat_channels=256, + octave_base_scale=8, + scales_per_octave=3, + octave_ratios=[0.5, 1.0, 2.0], + anchor_strides=[4, 8, 16, 32, 64], + anchor_base_sizes=None, + anchoring_means=(.0, .0, .0, .0), + anchoring_stds=(1.0, 1.0, 1.0, 1.0), + target_means=(.0, .0, .0, .0), + target_stds=(1.0, 1.0, 1.0, 1.0), + deformable_groups=4, + loc_filter_thr=0.01, + loss_loc=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_shape=dict(type='BoundedIoULoss', beta=0.2, loss_weight=1.0), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)): super(AnchorHead, self).__init__() self.in_channels = in_channels self.num_classes = num_classes @@ -169,6 +164,8 @@ class GuidedAnchorHead(AnchorHead): self.loss_cls = build_loss(loss_cls) self.loss_bbox = build_loss(loss_bbox) + self.fp16_enabled = False + self._init_layers() def _init_layers(self): @@ -392,6 +389,8 @@ class GuidedAnchorHead(AnchorHead): avg_factor=loc_avg_factor) return loss_loc + @force_fp32( + apply_to=('cls_scores', 'bbox_preds', 'shape_preds', 'loc_preds')) def loss(self, cls_scores, bbox_preds, @@ -502,6 +501,8 @@ class GuidedAnchorHead(AnchorHead): loss_shape=losses_shape, loss_loc=losses_loc) + @force_fp32( + apply_to=('cls_scores', 'bbox_preds', 'shape_preds', 'loc_preds')) def get_bboxes(self, cls_scores, bbox_preds, diff --git a/mmdet/models/anchor_heads/ssd_head.py b/mmdet/models/anchor_heads/ssd_head.py index c74a598..db86c47 100644 --- a/mmdet/models/anchor_heads/ssd_head.py +++ b/mmdet/models/anchor_heads/ssd_head.py @@ -92,6 +92,7 @@ class SSDHead(AnchorHead): self.target_stds = target_stds self.use_sigmoid_cls = False self.cls_focal_loss = False + self.fp16_enabled = False def init_weights(self): for m in self.modules(): diff --git a/mmdet/models/backbones/ssd_vgg.py b/mmdet/models/backbones/ssd_vgg.py index ffce9a9..f7ba8a4 100644 --- a/mmdet/models/backbones/ssd_vgg.py +++ b/mmdet/models/backbones/ssd_vgg.py @@ -6,6 +6,7 @@ import torch.nn.functional as F from mmcv.cnn import (VGG, xavier_init, constant_init, kaiming_init, normal_init) from mmcv.runner import load_checkpoint + from ..registry import BACKBONES @@ -126,5 +127,8 @@ class L2Norm(nn.Module): self.scale = scale def forward(self, x): - norm = x.pow(2).sum(1, keepdim=True).sqrt() + self.eps - return self.weight[None, :, None, None].expand_as(x) * x / norm + # normalization layer convert to FP32 in FP16 training + x_float = x.float() + norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps + return (self.weight[None, :, None, None].float().expand_as(x_float) * + x_float / norm).type_as(x) diff --git a/mmdet/models/bbox_heads/bbox_head.py 
b/mmdet/models/bbox_heads/bbox_head.py index f998859..df80570 100644 --- a/mmdet/models/bbox_heads/bbox_head.py +++ b/mmdet/models/bbox_heads/bbox_head.py @@ -2,7 +2,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from mmdet.core import delta2bbox, multiclass_nms, bbox_target +from mmdet.core import (delta2bbox, multiclass_nms, bbox_target, force_fp32, + auto_fp16) from ..builder import build_loss from ..losses import accuracy from ..registry import HEADS @@ -40,6 +41,7 @@ class BBoxHead(nn.Module): self.target_means = target_means self.target_stds = target_stds self.reg_class_agnostic = reg_class_agnostic + self.fp16_enabled = False self.loss_cls = build_loss(loss_cls) self.loss_bbox = build_loss(loss_bbox) @@ -64,6 +66,7 @@ class BBoxHead(nn.Module): nn.init.normal_(self.fc_reg.weight, 0, 0.001) nn.init.constant_(self.fc_reg.bias, 0) + @auto_fp16() def forward(self, x): if self.with_avg_pool: x = self.avg_pool(x) @@ -90,6 +93,7 @@ class BBoxHead(nn.Module): target_stds=self.target_stds) return cls_reg_targets + @force_fp32(apply_to=('cls_score', 'bbox_pred')) def loss(self, cls_score, bbox_pred, @@ -123,6 +127,7 @@ class BBoxHead(nn.Module): reduction_override=reduction_override) return losses + @force_fp32(apply_to=('cls_score', 'bbox_pred')) def get_det_bboxes(self, rois, cls_score, @@ -156,6 +161,7 @@ class BBoxHead(nn.Module): return det_bboxes, det_labels + @force_fp32(apply_to=('bbox_preds', )) def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas): """Refine bboxes during training. @@ -196,6 +202,7 @@ class BBoxHead(nn.Module): return bboxes_list + @force_fp32(apply_to=('bbox_pred', )) def regress_by_class(self, rois, label, bbox_pred, img_meta): """Regress the bbox for the predicted class. Used in Cascade R-CNN. 
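The decorator pattern applied to BBoxHead above generalizes to any custom module: forward-style methods go under @auto_fp16(), loss and post-processing methods under @force_fp32(), both gated by a `fp16_enabled` attribute that wrap_fp16_model() switches on. A minimal sketch under those assumptions — ToyHead and its layers are hypothetical, only the decorators and the flag come from this patch:

import torch.nn as nn
import torch.nn.functional as F

from mmdet.core import auto_fp16, force_fp32


class ToyHead(nn.Module):
    """Hypothetical head showing how a module opts into fp16 training."""

    def __init__(self, in_channels=256, num_classes=81):
        super(ToyHead, self).__init__()
        # wrap_fp16_model() sets this flag to True; while it is False,
        # both decorators simply fall back to the original methods.
        self.fp16_enabled = False
        self.fc_cls = nn.Linear(in_channels, num_classes)

    @auto_fp16()
    def forward(self, x):
        # fp32 tensor inputs are cast to fp16 when fp16_enabled is True
        return self.fc_cls(x)

    @force_fp32(apply_to=('cls_score', ))
    def loss(self, cls_score, labels):
        # cls_score is cast back to fp32 so the loss is computed in fp32
        return dict(loss_cls=F.cross_entropy(cls_score, labels))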
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py index 311ca90..96fb48e 100644 --- a/mmdet/models/detectors/base.py +++ b/mmdet/models/detectors/base.py @@ -6,7 +6,7 @@ import numpy as np import torch.nn as nn import pycocotools.mask as maskUtils -from mmdet.core import tensor2imgs, get_classes +from mmdet.core import tensor2imgs, get_classes, auto_fp16 class BaseDetector(nn.Module): @@ -16,6 +16,7 @@ class BaseDetector(nn.Module): def __init__(self): super(BaseDetector, self).__init__() + self.fp16_enabled = False @property def with_neck(self): @@ -79,6 +80,7 @@ class BaseDetector(nn.Module): else: return self.aug_test(imgs, img_metas, **kwargs) + @auto_fp16(apply_to=('img', )) def forward(self, img, img_meta, return_loss=True, **kwargs): if return_loss: return self.forward_train(img, img_meta, **kwargs) diff --git a/mmdet/models/mask_heads/fcn_mask_head.py b/mmdet/models/mask_heads/fcn_mask_head.py index 2136fff..af5cee8 100644 --- a/mmdet/models/mask_heads/fcn_mask_head.py +++ b/mmdet/models/mask_heads/fcn_mask_head.py @@ -7,7 +7,7 @@ import torch.nn as nn from ..builder import build_loss from ..registry import HEADS from ..utils import ConvModule -from mmdet.core import mask_target +from mmdet.core import mask_target, force_fp32, auto_fp16 @HEADS.register_module @@ -43,6 +43,7 @@ class FCNMaskHead(nn.Module): self.class_agnostic = class_agnostic self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg + self.fp16_enabled = False self.loss_mask = build_loss(loss_mask) self.convs = nn.ModuleList() @@ -88,6 +89,7 @@ class FCNMaskHead(nn.Module): m.weight, mode='fan_out', nonlinearity='relu') nn.init.constant_(m.bias, 0) + @auto_fp16() def forward(self, x): for conv in self.convs: x = conv(x) @@ -107,6 +109,7 @@ class FCNMaskHead(nn.Module): gt_masks, rcnn_train_cfg) return mask_targets + @force_fp32(apply_to=('mask_pred', )) def loss(self, mask_pred, mask_targets, labels): loss = dict() if self.class_agnostic: @@ -138,6 +141,9 @@ class FCNMaskHead(nn.Module): if isinstance(mask_pred, torch.Tensor): mask_pred = mask_pred.sigmoid().cpu().numpy() assert isinstance(mask_pred, np.ndarray) + # when enabling mixed precision training, mask_pred may be float16 + # numpy array + mask_pred = mask_pred.astype(np.float32) cls_segms = [[] for _ in range(self.num_classes - 1)] bboxes = det_bboxes.cpu().numpy()[:, :4] diff --git a/mmdet/models/mask_heads/fused_semantic_head.py b/mmdet/models/mask_heads/fused_semantic_head.py index 6107423..550e08e 100644 --- a/mmdet/models/mask_heads/fused_semantic_head.py +++ b/mmdet/models/mask_heads/fused_semantic_head.py @@ -2,6 +2,7 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import kaiming_init +from mmdet.core import auto_fp16, force_fp32 from ..registry import HEADS from ..utils import ConvModule @@ -43,6 +44,7 @@ class FusedSemanticHead(nn.Module): self.loss_weight = loss_weight self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg + self.fp16_enabled = False self.lateral_convs = nn.ModuleList() for i in range(self.num_ins): @@ -79,6 +81,7 @@ class FusedSemanticHead(nn.Module): def init_weights(self): kaiming_init(self.conv_logits) + @auto_fp16() def forward(self, feats): x = self.lateral_convs[self.fusion_level](feats[self.fusion_level]) fused_size = tuple(x.shape[-2:]) @@ -95,6 +98,7 @@ class FusedSemanticHead(nn.Module): x = self.conv_embedding(x) return mask_pred, x + @force_fp32(apply_to=('mask_pred',)) def loss(self, mask_pred, labels): labels = labels.squeeze(1).long() loss_semantic_seg = 
self.criterion(mask_pred, labels) diff --git a/mmdet/models/necks/fpn.py b/mmdet/models/necks/fpn.py index 6b8c862..d42fb1d 100644 --- a/mmdet/models/necks/fpn.py +++ b/mmdet/models/necks/fpn.py @@ -2,6 +2,7 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import xavier_init +from mmdet.core import auto_fp16 from ..registry import NECKS from ..utils import ConvModule @@ -29,6 +30,7 @@ class FPN(nn.Module): self.num_outs = num_outs self.activation = activation self.relu_before_extra_convs = relu_before_extra_convs + self.fp16_enabled = False if end_level == -1: self.backbone_end_level = self.num_ins @@ -94,6 +96,7 @@ class FPN(nn.Module): if isinstance(m, nn.Conv2d): xavier_init(m, distribution='uniform') + @auto_fp16() def forward(self, inputs): assert len(inputs) == len(self.in_channels) diff --git a/mmdet/models/roi_extractors/single_level.py b/mmdet/models/roi_extractors/single_level.py index 32709d5..6a1c1e7 100644 --- a/mmdet/models/roi_extractors/single_level.py +++ b/mmdet/models/roi_extractors/single_level.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn from mmdet import ops +from mmdet.core import force_fp32 from ..registry import ROI_EXTRACTORS @@ -31,6 +32,7 @@ class SingleRoIExtractor(nn.Module): self.out_channels = out_channels self.featmap_strides = featmap_strides self.finest_scale = finest_scale + self.fp16_enabled = False @property def num_inputs(self): @@ -70,6 +72,7 @@ class SingleRoIExtractor(nn.Module): target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long() return target_lvls + @force_fp32(apply_to=('feats',), out_fp16=True) def forward(self, feats, rois): if len(feats) == 1: return self.roi_layers[0](feats[0], rois) @@ -77,8 +80,8 @@ class SingleRoIExtractor(nn.Module): out_size = self.roi_layers[0].out_size num_levels = len(feats) target_lvls = self.map_roi_levels(rois, num_levels) - roi_feats = torch.cuda.FloatTensor(rois.size()[0], self.out_channels, - out_size, out_size).fill_(0) + roi_feats = feats[0].new_zeros(rois.size()[0], self.out_channels, + out_size, out_size) for i in range(num_levels): inds = target_lvls == i if inds.any(): diff --git a/mmdet/models/shared_heads/res_layer.py b/mmdet/models/shared_heads/res_layer.py index 743c2ee..cbc77ac 100644 --- a/mmdet/models/shared_heads/res_layer.py +++ b/mmdet/models/shared_heads/res_layer.py @@ -4,6 +4,7 @@ import torch.nn as nn from mmcv.cnn import constant_init, kaiming_init from mmcv.runner import load_checkpoint +from mmdet.core import auto_fp16 from ..backbones import ResNet, make_res_layer from ..registry import SHARED_HEADS @@ -25,6 +26,7 @@ class ResLayer(nn.Module): self.norm_eval = norm_eval self.norm_cfg = norm_cfg self.stage = stage + self.fp16_enabled = False block, stage_blocks = ResNet.arch_settings[depth] stage_block = stage_blocks[stage] planes = 64 * 2**stage @@ -56,6 +58,7 @@ class ResLayer(nn.Module): else: raise TypeError('pretrained must be a str or None') + @auto_fp16() def forward(self, x): res_layer = getattr(self, 'layer{}'.format(self.stage + 1)) out = res_layer(x) diff --git a/tools/test.py b/tools/test.py index df629f3..54f074d 100644 --- a/tools/test.py +++ b/tools/test.py @@ -11,7 +11,7 @@ from mmcv.runner import load_checkpoint, get_dist_info from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmdet.apis import init_dist -from mmdet.core import results2json, coco_eval +from mmdet.core import results2json, coco_eval, wrap_fp16_model from mmdet.datasets import build_dataloader, get_dataset from mmdet.models import 
build_detector @@ -157,6 +157,9 @@ def main(): # build the model and load checkpoint model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) + fp16_cfg = cfg.get('fp16', None) + if fp16_cfg is not None: + wrap_fp16_model(model) checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') # old versions did not save class info in checkpoints, this walkaround is # for backward compatibility -- GitLab
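From a user's perspective the feature reduces to a config switch plus one call at test time. A rough usage sketch stitched together from the fp16 configs and tools/test.py above (the checkpoint path is a placeholder, not something this patch provides):

import mmcv
from mmcv.runner import load_checkpoint

from mmdet.core import wrap_fp16_model
from mmdet.models import build_detector

# Training: adding `fp16 = dict(loss_scale=512.)` to a config is enough;
# mmdet/apis/train.py then registers an Fp16OptimizerHook instead of the
# default optimizer hook.

# Testing: mirror tools/test.py and wrap the model before loading weights.
cfg = mmcv.Config.fromfile('configs/fp16/retinanet_r50_fpn_fp16_1x.py')
model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
if cfg.get('fp16', None) is not None:
    # casts the model to half precision and patches BN/GN back to fp32
    wrap_fp16_model(model)
# the checkpoint path below is a placeholder for a trained fp16 model
load_checkpoint(model, 'path/to/checkpoint.pth', map_location='cpu')
model = model.cuda().eval()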