From e746639503c7c81c10f4c789616a21b089e058a9 Mon Sep 17 00:00:00 2001
From: Cao Yuhang <yhcao6@gmail.com>
Date: Fri, 28 Jun 2019 16:49:00 +0800
Subject: [PATCH] Support FP16 training (#520)

* add fp16 support

* fpn does not need bn normalize

* refactor wrapped bn

* fix bug of retinanet

* add fp16 ssd300 voc, cascade r50, cascade mask r50

* fix bug in cascade rcnn testing

* add support to fix bn training

* add fix bn cfg

* delete fixbn cfg, mv fixbn fp16 to a new branch

* fix cascade mask fp16 bug in test

* fix bug in cascade mask rcnn fp16 test

* add more fp16 cfgs

* add fp16 fast-r50 and faster-dconv-r50

* add fp16 test, minor fix

* clean code

* fix config work_dir name

* add patch func, refactor code

* fix format

* clean code

* move convert rois to single_level_extractor

* fix bug for cascade mask, the seg mask is ndarray

* refactor code, add two decorator force_fp32 and auto_fp16

* add fp16_enable attribute

* add more comment, fix format and test assertion

* fix pep8 format error

* format comments and api

* rename distribute to distributed, fix dict copy

* rename function name

* move function, add comment

* remove unused parameter

* move decorators into decorators.py and hook-related functions into hooks.py

* add auto_fp16 to forward of semantic head

* add auto_fp16 to all heads and fpn

* add docstrings and minor bug fix

* simple refactoring

* bug fix for patching forward method

* roi extractor in fp32 mode

* fix flake8 error

* fix ci error

* add fp16 support to ga head

* remove parallel test assert

* minor fix

* add comment in build_optimizer

* fix typo in comment

* fix typo enable --> enabling

* update README
---
 README.md                                     |   2 +-
 configs/fp16/faster_rcnn_r50_fpn_fp16_1x.py   | 170 ++++++++++++++++
 configs/fp16/mask_rcnn_r50_fpn_fp16_1x.py     | 185 ++++++++++++++++++
 configs/fp16/retinanet_r50_fpn_fp16_1x.py     | 127 ++++++++++++
 mmdet/apis/train.py                           |  30 ++-
 mmdet/core/__init__.py                        |   3 +-
 mmdet/core/fp16/__init__.py                   |   4 +
 mmdet/core/fp16/decorators.py                 | 160 +++++++++++++++
 mmdet/core/fp16/hooks.py                      | 126 ++++++++++++
 mmdet/core/fp16/utils.py                      |  23 +++
 mmdet/core/utils/dist_utils.py                |   7 +-
 mmdet/models/anchor_heads/anchor_head.py      |   5 +-
 mmdet/models/anchor_heads/fcos_head.py        |   8 +-
 .../models/anchor_heads/guided_anchor_head.py |  65 +++---
 mmdet/models/anchor_heads/ssd_head.py         |   1 +
 mmdet/models/backbones/ssd_vgg.py             |   8 +-
 mmdet/models/bbox_heads/bbox_head.py          |   9 +-
 mmdet/models/detectors/base.py                |   4 +-
 mmdet/models/mask_heads/fcn_mask_head.py      |   8 +-
 .../models/mask_heads/fused_semantic_head.py  |   4 +
 mmdet/models/necks/fpn.py                     |   3 +
 mmdet/models/roi_extractors/single_level.py   |   7 +-
 mmdet/models/shared_heads/res_layer.py        |   3 +
 tools/test.py                                 |   5 +-
 24 files changed, 915 insertions(+), 52 deletions(-)
 create mode 100644 configs/fp16/faster_rcnn_r50_fpn_fp16_1x.py
 create mode 100644 configs/fp16/mask_rcnn_r50_fpn_fp16_1x.py
 create mode 100644 configs/fp16/retinanet_r50_fpn_fp16_1x.py
 create mode 100644 mmdet/core/fp16/__init__.py
 create mode 100644 mmdet/core/fp16/decorators.py
 create mode 100644 mmdet/core/fp16/hooks.py
 create mode 100644 mmdet/core/fp16/utils.py

diff --git a/README.md b/README.md
index 7638ab0..b3e425d 100644
--- a/README.md
+++ b/README.md
@@ -106,7 +106,7 @@ Other features
 - [x] Soft-NMS
 - [x] Generalized Attention
 - [x] GCNet
-- [ ] Mixed Precision (FP16) Training (coming soon)
+- [x] Mixed Precision (FP16) Training
 
 
 ## Installation
diff --git a/configs/fp16/faster_rcnn_r50_fpn_fp16_1x.py b/configs/fp16/faster_rcnn_r50_fpn_fp16_1x.py
new file mode 100644
index 0000000..2e844a2
--- /dev/null
+++ b/configs/fp16/faster_rcnn_r50_fpn_fp16_1x.py
@@ -0,0 +1,170 @@
+# fp16 settings
+fp16 = dict(loss_scale=512.)
+
+# model settings
+model = dict(
+    type='FasterRCNN',
+    pretrained='modelzoo://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
+    # soft-nms is also supported for rcnn testing
+    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
+)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/faster_rcnn_r50_fpn_fp16_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
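
For reference, a minimal sketch of how the `fp16 = dict(loss_scale=512.)` entry above is consumed by the training entry point (mirroring the mmdet/apis/train.py hunk below); `cfg` is assumed to be this file loaded with mmcv's Config:

    from mmcv import Config

    from mmdet.core import DistOptimizerHook, Fp16OptimizerHook

    cfg = Config.fromfile('configs/fp16/faster_rcnn_r50_fpn_fp16_1x.py')
    fp16_cfg = cfg.get('fp16', None)  # dict(loss_scale=512.)
    if fp16_cfg is not None:
        # the loss is multiplied by loss_scale before backward; gradients are
        # divided by the same factor again before the optimizer step
        optimizer_config = Fp16OptimizerHook(**cfg.optimizer_config, **fp16_cfg)
    else:
        optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
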
diff --git a/configs/fp16/mask_rcnn_r50_fpn_fp16_1x.py b/configs/fp16/mask_rcnn_r50_fpn_fp16_1x.py
new file mode 100644
index 0000000..092978b
--- /dev/null
+++ b/configs/fp16/mask_rcnn_r50_fpn_fp16_1x.py
@@ -0,0 +1,185 @@
+# fp16 settings
+fp16 = dict(loss_scale=512.)
+
+# model settings
+model = dict(
+    type='MaskRCNN',
+    pretrained='modelzoo://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHead',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=81,
+        target_means=[0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
+    mask_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    mask_head=dict(
+        type='FCNMaskHead',
+        num_convs=4,
+        in_channels=256,
+        conv_out_channels=256,
+        num_classes=81,
+        loss_mask=dict(
+            type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)))
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.5,
+            neg_iou_thr=0.5,
+            min_pos_iou=0.5,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=512,
+            pos_fraction=0.25,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        mask_size=28,
+        pos_weight=-1,
+        debug=False))
+test_cfg = dict(
+    rpn=dict(
+        nms_across_levels=False,
+        nms_pre=1000,
+        nms_post=1000,
+        max_num=1000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05,
+        nms=dict(type='nms', iou_thr=0.5),
+        max_per_img=100,
+        mask_thr_binary=0.5))
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+evaluation = dict(interval=1)
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/mask_rcnn_r50_fpn_fp16_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/fp16/retinanet_r50_fpn_fp16_1x.py b/configs/fp16/retinanet_r50_fpn_fp16_1x.py
new file mode 100644
index 0000000..8b5ce0c
--- /dev/null
+++ b/configs/fp16/retinanet_r50_fpn_fp16_1x.py
@@ -0,0 +1,127 @@
+# fp16 settings
+fp16 = dict(loss_scale=512.)
+
+# model settings
+model = dict(
+    type='RetinaNet',
+    pretrained='modelzoo://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs=True,
+        num_outs=5),
+    bbox_head=dict(
+        type='RetinaHead',
+        num_classes=81,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        octave_base_scale=4,
+        scales_per_octave=3,
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[8, 16, 32, 64, 128],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=1.0)))
+# training and testing settings
+train_cfg = dict(
+    assigner=dict(
+        type='MaxIoUAssigner',
+        pos_iou_thr=0.5,
+        neg_iou_thr=0.4,
+        min_pos_iou=0,
+        ignore_iof_thr=-1),
+    allowed_border=-1,
+    pos_weight=-1,
+    debug=False)
+test_cfg = dict(
+    nms_pre=1000,
+    min_bbox_size=0,
+    score_thr=0.05,
+    nms=dict(type='nms', iou_thr=0.5),
+    max_per_img=100)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=False,
+        with_crowd=False,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=False,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        img_scale=(1333, 800),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_crowd=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+device_ids = range(8)
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/retinanet_r50_fpn_fp16_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
index 5e40dcf..60d6c7d 100644
--- a/mmdet/apis/train.py
+++ b/mmdet/apis/train.py
@@ -9,7 +9,8 @@ from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 
 from mmdet import datasets
 from mmdet.core import (DistOptimizerHook, DistEvalmAPHook,
-                        CocoDistEvalRecallHook, CocoDistEvalmAPHook)
+                        CocoDistEvalRecallHook, CocoDistEvalmAPHook,
+                        Fp16OptimizerHook)
 from mmdet.datasets import build_dataloader
 from mmdet.models import RPN
 from .env import get_root_logger
@@ -109,10 +110,14 @@ def build_optimizer(model, optimizer_cfg):
         # set param-wise lr and weight decay
         params = []
         for name, param in model.named_parameters():
+            param_group = {'params': [param]}
             if not param.requires_grad:
+                # FP16 training copies gradients and weights between the FP32
+                # master copy and the model weights, so it is convenient to
+                # keep all parameters here to align with model.parameters()
+                params.append(param_group)
                 continue
 
-            param_group = {'params': [param]}
             # for norm layers, overwrite the weight decay of weight and bias
             # TODO: obtain the norm layer prefixes dynamically
             if re.search(r'(bn|gn)(\d+)?.(weight|bias)', name):
@@ -142,12 +147,21 @@ def _dist_train(model, dataset, cfg, validate=False):
     ]
     # put model on gpus
     model = MMDistributedDataParallel(model.cuda())
+
     # build runner
     optimizer = build_optimizer(model, cfg.optimizer)
     runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                     cfg.log_level)
+
+    # fp16 setting
+    fp16_cfg = cfg.get('fp16', None)
+    if fp16_cfg is not None:
+        optimizer_config = Fp16OptimizerHook(**cfg.optimizer_config,
+                                             **fp16_cfg)
+    else:
+        optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
+
     # register hooks
-    optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
     runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config)
     runner.register_hook(DistSamplerSeedHook())
@@ -187,11 +201,19 @@ def _non_dist_train(model, dataset, cfg, validate=False):
     ]
     # put model on gpus
     model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
+
     # build runner
     optimizer = build_optimizer(model, cfg.optimizer)
     runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                     cfg.log_level)
-    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
+    # fp16 setting
+    fp16_cfg = cfg.get('fp16', None)
+    if fp16_cfg is not None:
+        optimizer_config = Fp16OptimizerHook(
+            **cfg.optimizer_config, **fp16_cfg, distributed=False)
+    else:
+        optimizer_config = cfg.optimizer_config
+    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config)
 
     if cfg.resume_from:
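
The comment in build_optimizer above is easiest to see with a toy model (hypothetical, for illustration only): by keeping parameters that do not require gradients, the optimizer's param_groups stay index-aligned with model.parameters(), which Fp16OptimizerHook relies on when it zips the fp32 master copy with the fp16 model parameters.

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
    for p in model[0].parameters():
        p.requires_grad = False  # e.g. a frozen backbone stage

    # one group per parameter, frozen or not, as in build_optimizer
    param_groups = [{'params': [p]} for _, p in model.named_parameters()]
    assert len(param_groups) == len(list(model.parameters()))
    for group, p in zip(param_groups, model.parameters()):
        assert group['params'][0] is p  # positional alignment holds
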
diff --git a/mmdet/core/__init__.py b/mmdet/core/__init__.py
index d118b14..f8eb6cb 100644
--- a/mmdet/core/__init__.py
+++ b/mmdet/core/__init__.py
@@ -1,6 +1,7 @@
 from .anchor import *  # noqa: F401, F403
 from .bbox import *  # noqa: F401, F403
-from .mask import *  # noqa: F401, F403
 from .evaluation import *  # noqa: F401, F403
+from .fp16 import *  # noqa: F401, F403
+from .mask import *  # noqa: F401, F403
 from .post_processing import *  # noqa: F401, F403
 from .utils import *  # noqa: F401, F403
diff --git a/mmdet/core/fp16/__init__.py b/mmdet/core/fp16/__init__.py
new file mode 100644
index 0000000..cc655b7
--- /dev/null
+++ b/mmdet/core/fp16/__init__.py
@@ -0,0 +1,4 @@
+from .decorators import auto_fp16, force_fp32
+from .hooks import Fp16OptimizerHook, wrap_fp16_model
+
+__all__ = ['auto_fp16', 'force_fp32', 'Fp16OptimizerHook', 'wrap_fp16_model']
diff --git a/mmdet/core/fp16/decorators.py b/mmdet/core/fp16/decorators.py
new file mode 100644
index 0000000..10ffbf8
--- /dev/null
+++ b/mmdet/core/fp16/decorators.py
@@ -0,0 +1,160 @@
+import functools
+from inspect import getfullargspec
+
+import torch
+
+from .utils import cast_tensor_type
+
+
+def auto_fp16(apply_to=None, out_fp32=False):
+    """Decorator to enable fp16 training automatically.
+
+    This decorator is useful when you write custom modules and want to support
+    mixed precision training. If input arguments are fp32 tensors, they will
+    be converted to fp16 automatically. Arguments other than fp32 tensors are
+    ignored.
+
+    Args:
+        apply_to (Iterable, optional): The argument names to be converted.
+            `None` indicates all arguments.
+        out_fp32 (bool): Whether to convert the output back to fp32.
+
+    :Example:
+
+        class MyModule1(nn.Module):
+
+            # Convert x and y to fp16
+            @auto_fp16()
+            def forward(self, x, y):
+                pass
+
+        class MyModule2(nn.Module):
+
+            # convert pred to fp16
+            @auto_fp16(apply_to=('pred', ))
+            def do_something(self, pred, others):
+                pass
+    """
+
+    def auto_fp16_wrapper(old_func):
+
+        @functools.wraps(old_func)
+        def new_func(*args, **kwargs):
+            # check whether the module has set the attribute `fp16_enabled`;
+            # if not, just fall back to the original method.
+            if not isinstance(args[0], torch.nn.Module):
+                raise TypeError('@auto_fp16 can only be used to decorate the '
+                                'method of nn.Module')
+            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
+                return old_func(*args, **kwargs)
+            # get the arg spec of the decorated method
+            args_info = getfullargspec(old_func)
+            # get the argument names to be cast
+            args_to_cast = args_info.args if apply_to is None else apply_to
+            # convert the args that need to be processed
+            new_args = []
+            # NOTE: default args are not taken into consideration
+            if args:
+                arg_names = args_info.args[:len(args)]
+                for i, arg_name in enumerate(arg_names):
+                    if arg_name in args_to_cast:
+                        new_args.append(
+                            cast_tensor_type(args[i], torch.float, torch.half))
+                    else:
+                        new_args.append(args[i])
+            # convert the kwargs that need to be processed
+            new_kwargs = {}
+            if kwargs:
+                for arg_name, arg_value in kwargs.items():
+                    if arg_name in args_to_cast:
+                        new_kwargs[arg_name] = cast_tensor_type(
+                            arg_value, torch.float, torch.half)
+                    else:
+                        new_kwargs[arg_name] = arg_value
+            # apply converted arguments to the decorated method
+            output = old_func(*new_args, **new_kwargs)
+            # cast the results back to fp32 if necessary
+            if out_fp32:
+                output = cast_tensor_type(output, torch.half, torch.float)
+            return output
+
+        return new_func
+
+    return auto_fp16_wrapper
+
+
+def force_fp32(apply_to=None, out_fp16=False):
+    """Decorator to convert input arguments to fp32 in force.
+
+    This decorator is useful when you write custom modules and want to support
+    mixed precision training. If there are some inputs that must be processed
+    in fp32 mode, then this decorator can handle it. If input arguments are
+    fp16 tensors, they will be converted to fp32 automatically. Arguments other
+    than fp16 tensors are ignored.
+
+    Args:
+        apply_to (Iterable, optional): The argument names to be converted.
+            `None` indicates all arguments.
+        out_fp16 (bool): Whether to convert the output back to fp16.
+
+    :Example:
+
+        class MyModule1(nn.Module):
+
+            # Convert x and y to fp32
+            @force_fp32()
+            def loss(self, x, y):
+                pass
+
+        class MyModule2(nn.Module):
+
+            # convert pred to fp32
+            @force_fp32(apply_to=('pred', ))
+            def post_process(self, pred, others):
+                pass
+    """
+
+    def force_fp32_wrapper(old_func):
+
+        @functools.wraps(old_func)
+        def new_func(*args, **kwargs):
+            # check whether the module has set the attribute `fp16_enabled`;
+            # if not, just fall back to the original method.
+            if not isinstance(args[0], torch.nn.Module):
+                raise TypeError('@force_fp32 can only be used to decorate the '
+                                'method of nn.Module')
+            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
+                return old_func(*args, **kwargs)
+            # get the arg spec of the decorated method
+            args_info = getfullargspec(old_func)
+            # get the argument names to be cast
+            args_to_cast = args_info.args if apply_to is None else apply_to
+            # convert the args that need to be processed
+            new_args = []
+            if args:
+                arg_names = args_info.args[:len(args)]
+                for i, arg_name in enumerate(arg_names):
+                    if arg_name in args_to_cast:
+                        new_args.append(
+                            cast_tensor_type(args[i], torch.half, torch.float))
+                    else:
+                        new_args.append(args[i])
+            # convert the kwargs that need to be processed
+            new_kwargs = dict()
+            if kwargs:
+                for arg_name, arg_value in kwargs.items():
+                    if arg_name in args_to_cast:
+                        new_kwargs[arg_name] = cast_tensor_type(
+                            arg_value, torch.half, torch.float)
+                    else:
+                        new_kwargs[arg_name] = arg_value
+            # apply converted arguments to the decorated method
+            output = old_func(*new_args, **new_kwargs)
+            # cast the results back to fp16 if necessary
+            if out_fp16:
+                output = cast_tensor_type(output, torch.float, torch.half)
+            return output
+
+        return new_func
+
+    return force_fp32_wrapper
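
A minimal usage sketch of the two decorators on a hypothetical head (the module below is illustrative only; real heads set `fp16_enabled` to False in __init__ and the flag is flipped by wrap_fp16_model):

    import torch.nn as nn

    from mmdet.core import auto_fp16, force_fp32


    class ToyHead(nn.Module):  # hypothetical example module

        def __init__(self):
            super(ToyHead, self).__init__()
            self.fp16_enabled = False  # set to True by wrap_fp16_model
            self.conv = nn.Conv2d(3, 8, 3, padding=1)

        @auto_fp16()
        def forward(self, x):
            # x is cast to fp16 here when fp16_enabled is True
            return self.conv(x)

        @force_fp32(apply_to=('pred', ))
        def loss(self, pred, target):
            # pred is cast back to fp32 so the loss is computed in fp32
            return {'loss': ((pred - target) ** 2).mean()}
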
diff --git a/mmdet/core/fp16/hooks.py b/mmdet/core/fp16/hooks.py
new file mode 100644
index 0000000..b1ab45e
--- /dev/null
+++ b/mmdet/core/fp16/hooks.py
@@ -0,0 +1,126 @@
+import copy
+import torch
+import torch.nn as nn
+from mmcv.runner import OptimizerHook
+
+from .utils import cast_tensor_type
+from ..utils.dist_utils import allreduce_grads
+
+
+class Fp16OptimizerHook(OptimizerHook):
+    """FP16 optimizer hook.
+
+    The steps of the fp16 optimizer are as follows.
+    1. Scale the loss value.
+    2. Backpropagate through the fp16 model.
+    3. Copy gradients from the fp16 model to the fp32 weights.
+    4. Update the fp32 weights.
+    5. Copy updated parameters from the fp32 weights to the fp16 model.
+
+    Refer to https://arxiv.org/abs/1710.03740 for more details.
+
+    Args:
+        loss_scale (float): Scale factor multiplied with loss.
+    """
+
+    def __init__(self,
+                 grad_clip=None,
+                 coalesce=True,
+                 bucket_size_mb=-1,
+                 loss_scale=512.,
+                 distributed=True):
+        self.grad_clip = grad_clip
+        self.coalesce = coalesce
+        self.bucket_size_mb = bucket_size_mb
+        self.loss_scale = loss_scale
+        self.distributed = distributed
+
+    def before_run(self, runner):
+        # keep a copy of fp32 weights
+        runner.optimizer.param_groups = copy.deepcopy(
+            runner.optimizer.param_groups)
+        # convert model to fp16
+        wrap_fp16_model(runner.model)
+
+    def copy_grads_to_fp32(self, fp16_net, fp32_weights):
+        """Copy gradients from fp16 model to fp32 weight copy."""
+        for fp32_param, fp16_param in zip(fp32_weights, fp16_net.parameters()):
+            if fp16_param.grad is not None:
+                if fp32_param.grad is None:
+                    fp32_param.grad = fp32_param.data.new(fp32_param.size())
+                fp32_param.grad.copy_(fp16_param.grad)
+
+    def copy_params_to_fp16(self, fp16_net, fp32_weights):
+        """Copy updated params from fp32 weight copy to fp16 model."""
+        for fp16_param, fp32_param in zip(fp16_net.parameters(), fp32_weights):
+            fp16_param.data.copy_(fp32_param.data)
+
+    def after_train_iter(self, runner):
+        # clear grads of last iteration
+        runner.model.zero_grad()
+        runner.optimizer.zero_grad()
+        # scale the loss value
+        scaled_loss = runner.outputs['loss'] * self.loss_scale
+        scaled_loss.backward()
+        # copy fp16 grads in the model to fp32 params in the optimizer
+        fp32_weights = []
+        for param_group in runner.optimizer.param_groups:
+            fp32_weights += param_group['params']
+        self.copy_grads_to_fp32(runner.model, fp32_weights)
+        # allreduce grads
+        if self.distributed:
+            allreduce_grads(fp32_weights, self.coalesce, self.bucket_size_mb)
+        # scale the gradients back
+        for param in fp32_weights:
+            if param.grad is not None:
+                param.grad.div_(self.loss_scale)
+        if self.grad_clip is not None:
+            self.clip_grads(fp32_weights)
+        # update fp32 params
+        runner.optimizer.step()
+        # copy fp32 params to the fp16 model
+        self.copy_params_to_fp16(runner.model, fp32_weights)
+
+
+def wrap_fp16_model(model):
+    # convert model to fp16
+    model.half()
+    # patch the normalization layers so that they run in fp32 mode
+    patch_norm_fp32(model)
+    # set `fp16_enabled` flag
+    for m in model.modules():
+        if hasattr(m, 'fp16_enabled'):
+            m.fp16_enabled = True
+
+
+def patch_norm_fp32(module):
+    if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
+        module.float()
+        module.forward = patch_forward_method(module.forward, torch.half,
+                                              torch.float)
+    for child in module.children():
+        patch_norm_fp32(child)
+    return module
+
+
+def patch_forward_method(func, src_type, dst_type, convert_output=True):
+    """Patch the forward method of a module.
+
+    Args:
+        func (callable): The original forward method.
+        src_type (torch.dtype): Type of input arguments to be converted from.
+        dst_type (torch.dtype): Type of input arguments to be converted to.
+        convert_output (bool): Whether to convert the output back to src_type.
+
+    Returns:
+        callable: The patched forward method.
+    """
+
+    def new_forward(*args, **kwargs):
+        output = func(*cast_tensor_type(args, src_type, dst_type),
+                      **cast_tensor_type(kwargs, src_type, dst_type))
+        if convert_output:
+            output = cast_tensor_type(output, dst_type, src_type)
+        return output
+
+    return new_forward
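
A small sketch of what wrap_fp16_model does to a module tree (the toy model is hypothetical): parameters are converted to fp16 while normalization layers are patched to keep running in fp32, with their inputs and outputs cast at the boundary.

    import torch
    import torch.nn as nn

    from mmdet.core import wrap_fp16_model

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    wrap_fp16_model(model)

    assert model[0].weight.dtype == torch.half   # conv weights are fp16
    assert model[1].weight.dtype == torch.float  # BN weights stay fp32
    # the patched BatchNorm forward casts fp16 inputs to fp32 and converts
    # its output back to fp16 (see patch_forward_method above)
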
diff --git a/mmdet/core/fp16/utils.py b/mmdet/core/fp16/utils.py
new file mode 100644
index 0000000..ce691c7
--- /dev/null
+++ b/mmdet/core/fp16/utils.py
@@ -0,0 +1,23 @@
+from collections import abc
+
+import numpy as np
+import torch
+
+
+def cast_tensor_type(inputs, src_type, dst_type):
+    if isinstance(inputs, torch.Tensor):
+        return inputs.to(dst_type)
+    elif isinstance(inputs, str):
+        return inputs
+    elif isinstance(inputs, np.ndarray):
+        return inputs
+    elif isinstance(inputs, abc.Mapping):
+        return type(inputs)({
+            k: cast_tensor_type(v, src_type, dst_type)
+            for k, v in inputs.items()
+        })
+    elif isinstance(inputs, abc.Iterable):
+        return type(inputs)(
+            cast_tensor_type(item, src_type, dst_type) for item in inputs)
+    else:
+        return inputs
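
cast_tensor_type recurses through mappings and iterables and only converts torch.Tensor leaves; strings, numpy arrays and scalars pass through unchanged. A quick check:

    import torch

    from mmdet.core.fp16.utils import cast_tensor_type

    batch = {'img': torch.randn(2, 3, 4, 4),
             'meta': ['a.jpg', 'b.jpg'],
             'scale': 1.5}
    out = cast_tensor_type(batch, torch.float, torch.half)
    assert out['img'].dtype == torch.half
    assert out['meta'] == ['a.jpg', 'b.jpg'] and out['scale'] == 1.5
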
diff --git a/mmdet/core/utils/dist_utils.py b/mmdet/core/utils/dist_utils.py
index ec84bb4..51d7e3c 100644
--- a/mmdet/core/utils/dist_utils.py
+++ b/mmdet/core/utils/dist_utils.py
@@ -28,9 +28,9 @@ def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
             tensor.copy_(synced)
 
 
-def allreduce_grads(model, coalesce=True, bucket_size_mb=-1):
+def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
     grads = [
-        param.grad.data for param in model.parameters()
+        param.grad.data for param in params
         if param.requires_grad and param.grad is not None
     ]
     world_size = dist.get_world_size()
@@ -51,7 +51,8 @@ class DistOptimizerHook(OptimizerHook):
     def after_train_iter(self, runner):
         runner.optimizer.zero_grad()
         runner.outputs['loss'].backward()
-        allreduce_grads(runner.model, self.coalesce, self.bucket_size_mb)
+        allreduce_grads(runner.model.parameters(), self.coalesce,
+                        self.bucket_size_mb)
         if self.grad_clip is not None:
             self.clip_grads(runner.model.parameters())
         runner.optimizer.step()
diff --git a/mmdet/models/anchor_heads/anchor_head.py b/mmdet/models/anchor_heads/anchor_head.py
index 1883e91..2b8b144 100644
--- a/mmdet/models/anchor_heads/anchor_head.py
+++ b/mmdet/models/anchor_heads/anchor_head.py
@@ -6,7 +6,7 @@ import torch.nn as nn
 from mmcv.cnn import normal_init
 
 from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
-                        multi_apply, multiclass_nms)
+                        multi_apply, multiclass_nms, force_fp32)
 from ..builder import build_loss
 from ..registry import HEADS
 
@@ -64,6 +64,7 @@ class AnchorHead(nn.Module):
             self.cls_out_channels = num_classes
         self.loss_cls = build_loss(loss_cls)
         self.loss_bbox = build_loss(loss_bbox)
+        self.fp16_enabled = False
 
         self.anchor_generators = []
         for anchor_base in self.anchor_base_sizes:
@@ -149,6 +150,7 @@ class AnchorHead(nn.Module):
             avg_factor=num_total_samples)
         return loss_cls, loss_bbox
 
+    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
     def loss(self,
              cls_scores,
              bbox_preds,
@@ -193,6 +195,7 @@ class AnchorHead(nn.Module):
             cfg=cfg)
         return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
 
+    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
     def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg,
                    rescale=False):
         assert len(cls_scores) == len(bbox_preds)
diff --git a/mmdet/models/anchor_heads/fcos_head.py b/mmdet/models/anchor_heads/fcos_head.py
index a5ad9bc..957906d 100644
--- a/mmdet/models/anchor_heads/fcos_head.py
+++ b/mmdet/models/anchor_heads/fcos_head.py
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 from mmcv.cnn import normal_init
 
-from mmdet.core import multi_apply, multiclass_nms, distance2bbox
+from mmdet.core import multi_apply, multiclass_nms, distance2bbox, force_fp32
 from ..builder import build_loss
 from ..registry import HEADS
 from ..utils import bias_init_with_prob, Scale, ConvModule
@@ -48,6 +48,7 @@ class FCOSHead(nn.Module):
         self.loss_centerness = build_loss(loss_centerness)
         self.conv_cfg = conv_cfg
         self.norm_cfg = norm_cfg
+        self.fp16_enabled = False
 
         self._init_layers()
 
@@ -108,9 +109,11 @@ class FCOSHead(nn.Module):
         for reg_layer in self.reg_convs:
             reg_feat = reg_layer(reg_feat)
         # scale the bbox_pred of different level
-        bbox_pred = scale(self.fcos_reg(reg_feat)).exp()
+        # cast to float to avoid overflow when FP16 is enabled
+        bbox_pred = scale(self.fcos_reg(reg_feat)).float().exp()
         return cls_score, bbox_pred, centerness
 
+    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
     def loss(self,
              cls_scores,
              bbox_preds,
@@ -183,6 +186,7 @@ class FCOSHead(nn.Module):
             loss_bbox=loss_bbox,
             loss_centerness=loss_centerness)
 
+    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
     def get_bboxes(self,
                    cls_scores,
                    bbox_preds,
diff --git a/mmdet/models/anchor_heads/guided_anchor_head.py b/mmdet/models/anchor_heads/guided_anchor_head.py
index 8b5dc54..c3cc705 100644
--- a/mmdet/models/anchor_heads/guided_anchor_head.py
+++ b/mmdet/models/anchor_heads/guided_anchor_head.py
@@ -7,7 +7,7 @@ from mmcv.cnn import normal_init
 
 from mmdet.core import (AnchorGenerator, anchor_target, anchor_inside_flags,
                         ga_loc_target, ga_shape_target, delta2bbox,
-                        multi_apply, multiclass_nms)
+                        multi_apply, multiclass_nms, force_fp32)
 from mmdet.ops import DeformConv, MaskedConv2d
 from ..builder import build_loss
 from .anchor_head import AnchorHead
@@ -93,37 +93,32 @@ class GuidedAnchorHead(AnchorHead):
         loss_bbox (dict): Config of bbox regression loss.
     """
 
-    def __init__(self,
-                 num_classes,
-                 in_channels,
-                 feat_channels=256,
-                 octave_base_scale=8,
-                 scales_per_octave=3,
-                 octave_ratios=[0.5, 1.0, 2.0],
-                 anchor_strides=[4, 8, 16, 32, 64],
-                 anchor_base_sizes=None,
-                 anchoring_means=(.0, .0, .0, .0),
-                 anchoring_stds=(1.0, 1.0, 1.0, 1.0),
-                 target_means=(.0, .0, .0, .0),
-                 target_stds=(1.0, 1.0, 1.0, 1.0),
-                 deformable_groups=4,
-                 loc_filter_thr=0.01,
-                 loss_loc=dict(
-                     type='FocalLoss',
-                     use_sigmoid=True,
-                     gamma=2.0,
-                     alpha=0.25,
-                     loss_weight=1.0),
-                 loss_shape=dict(
-                     type='BoundedIoULoss',
-                     beta=0.2,
-                     loss_weight=1.0),
-                 loss_cls=dict(
-                     type='CrossEntropyLoss',
-                     use_sigmoid=True,
-                     loss_weight=1.0),
-                 loss_bbox=dict(
-                     type='SmoothL1Loss', beta=1.0, loss_weight=1.0)):
+    def __init__(
+            self,
+            num_classes,
+            in_channels,
+            feat_channels=256,
+            octave_base_scale=8,
+            scales_per_octave=3,
+            octave_ratios=[0.5, 1.0, 2.0],
+            anchor_strides=[4, 8, 16, 32, 64],
+            anchor_base_sizes=None,
+            anchoring_means=(.0, .0, .0, .0),
+            anchoring_stds=(1.0, 1.0, 1.0, 1.0),
+            target_means=(.0, .0, .0, .0),
+            target_stds=(1.0, 1.0, 1.0, 1.0),
+            deformable_groups=4,
+            loc_filter_thr=0.01,
+            loss_loc=dict(
+                type='FocalLoss',
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=1.0),
+            loss_shape=dict(type='BoundedIoULoss', beta=0.2, loss_weight=1.0),
+            loss_cls=dict(
+                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)):
         super(AnchorHead, self).__init__()
         self.in_channels = in_channels
         self.num_classes = num_classes
@@ -169,6 +164,8 @@ class GuidedAnchorHead(AnchorHead):
         self.loss_cls = build_loss(loss_cls)
         self.loss_bbox = build_loss(loss_bbox)
 
+        self.fp16_enabled = False
+
         self._init_layers()
 
     def _init_layers(self):
@@ -392,6 +389,8 @@ class GuidedAnchorHead(AnchorHead):
             avg_factor=loc_avg_factor)
         return loss_loc
 
+    @force_fp32(
+        apply_to=('cls_scores', 'bbox_preds', 'shape_preds', 'loc_preds'))
     def loss(self,
              cls_scores,
              bbox_preds,
@@ -502,6 +501,8 @@ class GuidedAnchorHead(AnchorHead):
             loss_shape=losses_shape,
             loss_loc=losses_loc)
 
+    @force_fp32(
+        apply_to=('cls_scores', 'bbox_preds', 'shape_preds', 'loc_preds'))
     def get_bboxes(self,
                    cls_scores,
                    bbox_preds,
diff --git a/mmdet/models/anchor_heads/ssd_head.py b/mmdet/models/anchor_heads/ssd_head.py
index c74a598..db86c47 100644
--- a/mmdet/models/anchor_heads/ssd_head.py
+++ b/mmdet/models/anchor_heads/ssd_head.py
@@ -92,6 +92,7 @@ class SSDHead(AnchorHead):
         self.target_stds = target_stds
         self.use_sigmoid_cls = False
         self.cls_focal_loss = False
+        self.fp16_enabled = False
 
     def init_weights(self):
         for m in self.modules():
diff --git a/mmdet/models/backbones/ssd_vgg.py b/mmdet/models/backbones/ssd_vgg.py
index ffce9a9..f7ba8a4 100644
--- a/mmdet/models/backbones/ssd_vgg.py
+++ b/mmdet/models/backbones/ssd_vgg.py
@@ -6,6 +6,7 @@ import torch.nn.functional as F
 from mmcv.cnn import (VGG, xavier_init, constant_init, kaiming_init,
                       normal_init)
 from mmcv.runner import load_checkpoint
+
 from ..registry import BACKBONES
 
 
@@ -126,5 +127,8 @@ class L2Norm(nn.Module):
         self.scale = scale
 
     def forward(self, x):
-        norm = x.pow(2).sum(1, keepdim=True).sqrt() + self.eps
-        return self.weight[None, :, None, None].expand_as(x) * x / norm
+        # compute the norm in FP32 to avoid overflow during FP16 training
+        x_float = x.float()
+        norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps
+        return (self.weight[None, :, None, None].float().expand_as(x_float) *
+                x_float / norm).type_as(x)
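
The fp32 cast in L2Norm matters because squared activations can exceed the fp16 range; for example 300 squared is 90000, above the largest finite fp16 value:

    import torch

    print(torch.finfo(torch.half).max)     # 65504.0
    print(torch.tensor([90000.0]).half())  # tensor([inf], dtype=torch.float16)
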
diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py
index f998859..df80570 100644
--- a/mmdet/models/bbox_heads/bbox_head.py
+++ b/mmdet/models/bbox_heads/bbox_head.py
@@ -2,7 +2,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from mmdet.core import delta2bbox, multiclass_nms, bbox_target
+from mmdet.core import (delta2bbox, multiclass_nms, bbox_target, force_fp32,
+                        auto_fp16)
 from ..builder import build_loss
 from ..losses import accuracy
 from ..registry import HEADS
@@ -40,6 +41,7 @@ class BBoxHead(nn.Module):
         self.target_means = target_means
         self.target_stds = target_stds
         self.reg_class_agnostic = reg_class_agnostic
+        self.fp16_enabled = False
 
         self.loss_cls = build_loss(loss_cls)
         self.loss_bbox = build_loss(loss_bbox)
@@ -64,6 +66,7 @@ class BBoxHead(nn.Module):
             nn.init.normal_(self.fc_reg.weight, 0, 0.001)
             nn.init.constant_(self.fc_reg.bias, 0)
 
+    @auto_fp16()
     def forward(self, x):
         if self.with_avg_pool:
             x = self.avg_pool(x)
@@ -90,6 +93,7 @@ class BBoxHead(nn.Module):
             target_stds=self.target_stds)
         return cls_reg_targets
 
+    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
     def loss(self,
              cls_score,
              bbox_pred,
@@ -123,6 +127,7 @@ class BBoxHead(nn.Module):
                 reduction_override=reduction_override)
         return losses
 
+    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
     def get_det_bboxes(self,
                        rois,
                        cls_score,
@@ -156,6 +161,7 @@ class BBoxHead(nn.Module):
 
             return det_bboxes, det_labels
 
+    @force_fp32(apply_to=('bbox_preds', ))
     def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
         """Refine bboxes during training.
 
@@ -196,6 +202,7 @@ class BBoxHead(nn.Module):
 
         return bboxes_list
 
+    @force_fp32(apply_to=('bbox_pred', ))
     def regress_by_class(self, rois, label, bbox_pred, img_meta):
         """Regress the bbox for the predicted class. Used in Cascade R-CNN.
 
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
index 311ca90..96fb48e 100644
--- a/mmdet/models/detectors/base.py
+++ b/mmdet/models/detectors/base.py
@@ -6,7 +6,7 @@ import numpy as np
 import torch.nn as nn
 import pycocotools.mask as maskUtils
 
-from mmdet.core import tensor2imgs, get_classes
+from mmdet.core import tensor2imgs, get_classes, auto_fp16
 
 
 class BaseDetector(nn.Module):
@@ -16,6 +16,7 @@ class BaseDetector(nn.Module):
 
     def __init__(self):
         super(BaseDetector, self).__init__()
+        self.fp16_enabled = False
 
     @property
     def with_neck(self):
@@ -79,6 +80,7 @@ class BaseDetector(nn.Module):
         else:
             return self.aug_test(imgs, img_metas, **kwargs)
 
+    @auto_fp16(apply_to=('img', ))
     def forward(self, img, img_meta, return_loss=True, **kwargs):
         if return_loss:
             return self.forward_train(img, img_meta, **kwargs)
diff --git a/mmdet/models/mask_heads/fcn_mask_head.py b/mmdet/models/mask_heads/fcn_mask_head.py
index 2136fff..af5cee8 100644
--- a/mmdet/models/mask_heads/fcn_mask_head.py
+++ b/mmdet/models/mask_heads/fcn_mask_head.py
@@ -7,7 +7,7 @@ import torch.nn as nn
 from ..builder import build_loss
 from ..registry import HEADS
 from ..utils import ConvModule
-from mmdet.core import mask_target
+from mmdet.core import mask_target, force_fp32, auto_fp16
 
 
 @HEADS.register_module
@@ -43,6 +43,7 @@ class FCNMaskHead(nn.Module):
         self.class_agnostic = class_agnostic
         self.conv_cfg = conv_cfg
         self.norm_cfg = norm_cfg
+        self.fp16_enabled = False
         self.loss_mask = build_loss(loss_mask)
 
         self.convs = nn.ModuleList()
@@ -88,6 +89,7 @@ class FCNMaskHead(nn.Module):
                 m.weight, mode='fan_out', nonlinearity='relu')
             nn.init.constant_(m.bias, 0)
 
+    @auto_fp16()
     def forward(self, x):
         for conv in self.convs:
             x = conv(x)
@@ -107,6 +109,7 @@ class FCNMaskHead(nn.Module):
                                    gt_masks, rcnn_train_cfg)
         return mask_targets
 
+    @force_fp32(apply_to=('mask_pred', ))
     def loss(self, mask_pred, mask_targets, labels):
         loss = dict()
         if self.class_agnostic:
@@ -138,6 +141,9 @@ class FCNMaskHead(nn.Module):
         if isinstance(mask_pred, torch.Tensor):
             mask_pred = mask_pred.sigmoid().cpu().numpy()
         assert isinstance(mask_pred, np.ndarray)
+        # when mixed precision training is enabled, mask_pred may be a
+        # float16 numpy array
+        mask_pred = mask_pred.astype(np.float32)
 
         cls_segms = [[] for _ in range(self.num_classes - 1)]
         bboxes = det_bboxes.cpu().numpy()[:, :4]
diff --git a/mmdet/models/mask_heads/fused_semantic_head.py b/mmdet/models/mask_heads/fused_semantic_head.py
index 6107423..550e08e 100644
--- a/mmdet/models/mask_heads/fused_semantic_head.py
+++ b/mmdet/models/mask_heads/fused_semantic_head.py
@@ -2,6 +2,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.cnn import kaiming_init
 
+from mmdet.core import auto_fp16, force_fp32
 from ..registry import HEADS
 from ..utils import ConvModule
 
@@ -43,6 +44,7 @@ class FusedSemanticHead(nn.Module):
         self.loss_weight = loss_weight
         self.conv_cfg = conv_cfg
         self.norm_cfg = norm_cfg
+        self.fp16_enabled = False
 
         self.lateral_convs = nn.ModuleList()
         for i in range(self.num_ins):
@@ -79,6 +81,7 @@ class FusedSemanticHead(nn.Module):
     def init_weights(self):
         kaiming_init(self.conv_logits)
 
+    @auto_fp16()
     def forward(self, feats):
         x = self.lateral_convs[self.fusion_level](feats[self.fusion_level])
         fused_size = tuple(x.shape[-2:])
@@ -95,6 +98,7 @@ class FusedSemanticHead(nn.Module):
         x = self.conv_embedding(x)
         return mask_pred, x
 
+    @force_fp32(apply_to=('mask_pred',))
     def loss(self, mask_pred, labels):
         labels = labels.squeeze(1).long()
         loss_semantic_seg = self.criterion(mask_pred, labels)
diff --git a/mmdet/models/necks/fpn.py b/mmdet/models/necks/fpn.py
index 6b8c862..d42fb1d 100644
--- a/mmdet/models/necks/fpn.py
+++ b/mmdet/models/necks/fpn.py
@@ -2,6 +2,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.cnn import xavier_init
 
+from mmdet.core import auto_fp16
 from ..registry import NECKS
 from ..utils import ConvModule
 
@@ -29,6 +30,7 @@ class FPN(nn.Module):
         self.num_outs = num_outs
         self.activation = activation
         self.relu_before_extra_convs = relu_before_extra_convs
+        self.fp16_enabled = False
 
         if end_level == -1:
             self.backbone_end_level = self.num_ins
@@ -94,6 +96,7 @@ class FPN(nn.Module):
             if isinstance(m, nn.Conv2d):
                 xavier_init(m, distribution='uniform')
 
+    @auto_fp16()
     def forward(self, inputs):
         assert len(inputs) == len(self.in_channels)
 
diff --git a/mmdet/models/roi_extractors/single_level.py b/mmdet/models/roi_extractors/single_level.py
index 32709d5..6a1c1e7 100644
--- a/mmdet/models/roi_extractors/single_level.py
+++ b/mmdet/models/roi_extractors/single_level.py
@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 
 from mmdet import ops
+from mmdet.core import force_fp32
 from ..registry import ROI_EXTRACTORS
 
 
@@ -31,6 +32,7 @@ class SingleRoIExtractor(nn.Module):
         self.out_channels = out_channels
         self.featmap_strides = featmap_strides
         self.finest_scale = finest_scale
+        self.fp16_enabled = False
 
     @property
     def num_inputs(self):
@@ -70,6 +72,7 @@ class SingleRoIExtractor(nn.Module):
         target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
         return target_lvls
 
+    @force_fp32(apply_to=('feats',), out_fp16=True)
     def forward(self, feats, rois):
         if len(feats) == 1:
             return self.roi_layers[0](feats[0], rois)
@@ -77,8 +80,8 @@ class SingleRoIExtractor(nn.Module):
         out_size = self.roi_layers[0].out_size
         num_levels = len(feats)
         target_lvls = self.map_roi_levels(rois, num_levels)
-        roi_feats = torch.cuda.FloatTensor(rois.size()[0], self.out_channels,
-                                           out_size, out_size).fill_(0)
+        roi_feats = feats[0].new_zeros(rois.size()[0], self.out_channels,
+                                       out_size, out_size)
         for i in range(num_levels):
             inds = target_lvls == i
             if inds.any():
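
The switch from torch.cuda.FloatTensor to new_zeros above lets the RoI feature buffer inherit both dtype and device from the input feature map, so the same code works for fp32 and fp16 features, e.g.:

    import torch

    feat = torch.randn(1, 256, 8, 8).half()
    roi_feats = feat.new_zeros(4, 256, 7, 7)
    assert roi_feats.dtype == feat.dtype and roi_feats.device == feat.device
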
diff --git a/mmdet/models/shared_heads/res_layer.py b/mmdet/models/shared_heads/res_layer.py
index 743c2ee..cbc77ac 100644
--- a/mmdet/models/shared_heads/res_layer.py
+++ b/mmdet/models/shared_heads/res_layer.py
@@ -4,6 +4,7 @@ import torch.nn as nn
 from mmcv.cnn import constant_init, kaiming_init
 from mmcv.runner import load_checkpoint
 
+from mmdet.core import auto_fp16
 from ..backbones import ResNet, make_res_layer
 from ..registry import SHARED_HEADS
 
@@ -25,6 +26,7 @@ class ResLayer(nn.Module):
         self.norm_eval = norm_eval
         self.norm_cfg = norm_cfg
         self.stage = stage
+        self.fp16_enabled = False
         block, stage_blocks = ResNet.arch_settings[depth]
         stage_block = stage_blocks[stage]
         planes = 64 * 2**stage
@@ -56,6 +58,7 @@ class ResLayer(nn.Module):
         else:
             raise TypeError('pretrained must be a str or None')
 
+    @auto_fp16()
     def forward(self, x):
         res_layer = getattr(self, 'layer{}'.format(self.stage + 1))
         out = res_layer(x)
diff --git a/tools/test.py b/tools/test.py
index df629f3..54f074d 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -11,7 +11,7 @@ from mmcv.runner import load_checkpoint, get_dist_info
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 
 from mmdet.apis import init_dist
-from mmdet.core import results2json, coco_eval
+from mmdet.core import results2json, coco_eval, wrap_fp16_model
 from mmdet.datasets import build_dataloader, get_dataset
 from mmdet.models import build_detector
 
@@ -157,6 +157,9 @@ def main():
 
     # build the model and load checkpoint
     model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
+    fp16_cfg = cfg.get('fp16', None)
+    if fp16_cfg is not None:
+        wrap_fp16_model(model)
     checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
     # old versions did not save class info in checkpoints, this workaround is
     # for backward compatibility
-- 
GitLab