Skip to content
Snippets Groups Projects
Commit bd8fd27e authored by Cao Yuhang's avatar Cao Yuhang
Browse files

rename pos_mask to pos_inds

parent cb68807f
No related branches found
No related tags found
No related merge requests found
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.core import (delta2bbox, multiclass_nms, bbox_target,
weighted_cross_entropy, weighted_smoothl1, accuracy)
from ..registry import HEADS
......@@ -94,16 +94,16 @@ class BBoxHead(nn.Module):
cls_score, labels, label_weights, reduce=reduce)
losses['acc'] = accuracy(cls_score, labels)
if bbox_pred is not None:
pos_mask = labels > 0
pos_inds = labels > 0
if self.reg_class_agnostic:
pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_mask]
pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_inds]
else:
pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1,
4)[pos_mask, labels[pos_mask]]
4)[pos_inds, labels[pos_inds]]
losses['loss_reg'] = weighted_smoothl1(
pos_bbox_pred,
bbox_targets[pos_mask],
bbox_weights[pos_mask],
bbox_targets[pos_inds],
bbox_weights[pos_inds],
avg_factor=bbox_targets.size(0))
return losses
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment