From cb0dd8ee6777ecc658be3798b4ff46fcac966e16 Mon Sep 17 00:00:00 2001
From: Cao Yuhang <yhcao6@gmail.com>
Date: Sat, 13 Jul 2019 22:54:23 +0800
Subject: [PATCH] support fp16 for maskiou_head (#986)

---
 mmdet/models/mask_heads/maskiou_head.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/mmdet/models/mask_heads/maskiou_head.py b/mmdet/models/mask_heads/maskiou_head.py
index bfa1764..457a560 100644
--- a/mmdet/models/mask_heads/maskiou_head.py
+++ b/mmdet/models/mask_heads/maskiou_head.py
@@ -2,6 +2,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 from mmcv.cnn import kaiming_init, normal_init
+from mmdet.core import force_fp32
 
 from ..builder import build_loss
 from ..registry import HEADS
@@ -28,6 +29,7 @@ class MaskIoUHead(nn.Module):
         self.conv_out_channels = conv_out_channels
         self.fc_out_channels = fc_out_channels
         self.num_classes = num_classes
+        self.fp16_enabled = False
 
         self.convs = nn.ModuleList()
         for i in range(num_convs):
@@ -82,6 +84,7 @@ class MaskIoUHead(nn.Module):
         mask_iou = self.fc_mask_iou(x)
         return mask_iou
 
+    @force_fp32(apply_to=('mask_iou_pred', ))
     def loss(self, mask_iou_pred, mask_iou_targets):
         pos_inds = mask_iou_targets > 0
         if pos_inds.sum() > 0:
@@ -91,6 +94,7 @@ class MaskIoUHead(nn.Module):
             loss_mask_iou = mask_iou_pred * 0
         return dict(loss_mask_iou=loss_mask_iou)
 
+    @force_fp32(apply_to=('mask_pred', ))
     def get_target(self, sampling_results, gt_masks, mask_pred, mask_targets,
                    rcnn_train_cfg):
         """Compute target of mask IoU.
@@ -166,6 +170,7 @@ class MaskIoUHead(nn.Module):
             area_ratios = pos_proposals.new_zeros((0, ))
         return area_ratios
 
+    @force_fp32(apply_to=('mask_iou_pred', ))
     def get_mask_scores(self, mask_iou_pred, det_bboxes, det_labels):
         """Get the mask scores.
 
-- 
GitLab