From cb68807ffcaa157b9ee0826fecd61fb439efe9b3 Mon Sep 17 00:00:00 2001
From: Cao Yuhang <yhcao6@gmail.com>
Date: Sun, 14 Apr 2019 12:38:24 +0800
Subject: [PATCH] remove expand loop in bbox head to speed up

---
 mmdet/core/bbox/bbox_target.py       |  3 ---
 mmdet/models/bbox_heads/bbox_head.py | 14 ++++++++++----
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/mmdet/core/bbox/bbox_target.py b/mmdet/core/bbox/bbox_target.py
index 4a0450d..aa1fbc6 100644
--- a/mmdet/core/bbox/bbox_target.py
+++ b/mmdet/core/bbox/bbox_target.py
@@ -57,9 +57,6 @@ def bbox_target_single(pos_bboxes,
         bbox_weights[:num_pos, :] = 1
     if num_neg > 0:
         label_weights[-num_neg:] = 1.0
-    if reg_classes > 1:
-        bbox_targets, bbox_weights = expand_target(bbox_targets, bbox_weights,
-                                                   labels, reg_classes)
 
     return labels, label_weights, bbox_targets, bbox_weights
 
diff --git a/mmdet/models/bbox_heads/bbox_head.py b/mmdet/models/bbox_heads/bbox_head.py
index 2168e2e..092a812 100644
--- a/mmdet/models/bbox_heads/bbox_head.py
+++ b/mmdet/models/bbox_heads/bbox_head.py
@@ -1,9 +1,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
 from mmdet.core import (delta2bbox, multiclass_nms, bbox_target,
                         weighted_cross_entropy, weighted_smoothl1, accuracy)
+
 from ..registry import HEADS
 
 
@@ -94,10 +94,16 @@ class BBoxHead(nn.Module):
                 cls_score, labels, label_weights, reduce=reduce)
             losses['acc'] = accuracy(cls_score, labels)
         if bbox_pred is not None:
+            pos_mask = labels > 0
+            if self.reg_class_agnostic:
+                pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), 4)[pos_mask]
+            else:
+                pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1,
+                                               4)[pos_mask, labels[pos_mask]]
             losses['loss_reg'] = weighted_smoothl1(
-                bbox_pred,
-                bbox_targets,
-                bbox_weights,
+                pos_bbox_pred,
+                bbox_targets[pos_mask],
+                bbox_weights[pos_mask],
                 avg_factor=bbox_targets.size(0))
         return losses
 
-- 
GitLab