diff --git a/mmdet/models/mask_heads/maskiou_head.py b/mmdet/models/mask_heads/maskiou_head.py index bfa1764345d1ef399ee66d9da11760086db0076f..457a560cc1fe06a6ee390226aa7b92111e5e0343 100644 --- a/mmdet/models/mask_heads/maskiou_head.py +++ b/mmdet/models/mask_heads/maskiou_head.py @@ -2,6 +2,7 @@ import numpy as np import torch import torch.nn as nn from mmcv.cnn import kaiming_init, normal_init +from mmdet.core import force_fp32 from ..builder import build_loss from ..registry import HEADS @@ -28,6 +29,7 @@ class MaskIoUHead(nn.Module): self.conv_out_channels = conv_out_channels self.fc_out_channels = fc_out_channels self.num_classes = num_classes + self.fp16_enabled = False self.convs = nn.ModuleList() for i in range(num_convs): @@ -82,6 +84,7 @@ class MaskIoUHead(nn.Module): mask_iou = self.fc_mask_iou(x) return mask_iou + @force_fp32(apply_to=('mask_iou_pred', )) def loss(self, mask_iou_pred, mask_iou_targets): pos_inds = mask_iou_targets > 0 if pos_inds.sum() > 0: @@ -91,6 +94,7 @@ class MaskIoUHead(nn.Module): loss_mask_iou = mask_iou_pred * 0 return dict(loss_mask_iou=loss_mask_iou) + @force_fp32(apply_to=('mask_pred', )) def get_target(self, sampling_results, gt_masks, mask_pred, mask_targets, rcnn_train_cfg): """Compute target of mask IoU. @@ -166,6 +170,7 @@ class MaskIoUHead(nn.Module): area_ratios = pos_proposals.new_zeros((0, )) return area_ratios + @force_fp32(apply_to=('mask_iou_pred', )) def get_mask_scores(self, mask_iou_pred, det_bboxes, det_labels): """Get the mask scores.