Skip to content
Snippets Groups Projects
Commit cb68807f authored by Cao Yuhang's avatar Cao Yuhang
Browse files

remove expand loop in bbox head to speed up

parent d5a6b5dc
No related branches found
No related tags found
No related merge requests found
......@@ -57,9 +57,6 @@ def bbox_target_single(pos_bboxes,
bbox_weights[:num_pos, :] = 1
if num_neg > 0:
label_weights[-num_neg:] = 1.0
if reg_classes > 1:
bbox_targets, bbox_weights = expand_target(bbox_targets, bbox_weights,
labels, reg_classes)
return labels, label_weights, bbox_targets, bbox_weights
......
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.core import (delta2bbox, multiclass_nms, bbox_target,
weighted_cross_entropy, weighted_smoothl1, accuracy)
from ..registry import HEADS
......@@ -94,10 +94,16 @@ class BBoxHead(nn.Module):
cls_score, labels, label_weights, reduce=reduce)
losses['acc'] = accuracy(cls_score, labels)
if bbox_pred is not None:
pos_mask = labels > 0
if self.reg_class_agnostic:
pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_mask]
else:
pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1,
4)[pos_mask, labels[pos_mask]]
losses['loss_reg'] = weighted_smoothl1(
bbox_pred,
bbox_targets,
bbox_weights,
pos_bbox_pred,
bbox_targets[pos_mask],
bbox_weights[pos_mask],
avg_factor=bbox_targets.size(0))
return losses
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment