From d9172919d00695f0dddb0ab10c89c3b7a8e26427 Mon Sep 17 00:00:00 2001
From: Jon Crall <erotemic@gmail.com>
Date: Wed, 23 Oct 2019 11:54:26 -0400
Subject: [PATCH] Fix SSD Head and GHM Loss on CPU (#1578)

* Fix GHM loss on CPU

* Fix SSD head on CPU
---
 mmdet/models/anchor_heads/ssd_head.py |  4 +++-
 mmdet/models/losses/ghm_loss.py       | 12 ++++++++----
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/mmdet/models/anchor_heads/ssd_head.py b/mmdet/models/anchor_heads/ssd_head.py
index 600dd4a..056dfa5 100644
--- a/mmdet/models/anchor_heads/ssd_head.py
+++ b/mmdet/models/anchor_heads/ssd_head.py
@@ -143,8 +143,10 @@ class SSDHead(AnchorHead):
         featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
         assert len(featmap_sizes) == len(self.anchor_generators)
 
+        device = cls_scores[0].device
+
         anchor_list, valid_flag_list = self.get_anchors(
-            featmap_sizes, img_metas)
+            featmap_sizes, img_metas, device=device)
         cls_reg_targets = anchor_target(
             anchor_list,
             valid_flag_list,
diff --git a/mmdet/models/losses/ghm_loss.py b/mmdet/models/losses/ghm_loss.py
index 95656a2..e62b990 100644
--- a/mmdet/models/losses/ghm_loss.py
+++ b/mmdet/models/losses/ghm_loss.py
@@ -35,10 +35,12 @@ class GHMC(nn.Module):
         super(GHMC, self).__init__()
         self.bins = bins
         self.momentum = momentum
-        self.edges = torch.arange(bins + 1).float().cuda() / bins
+        edges = torch.arange(bins + 1).float() / bins
+        self.register_buffer('edges', edges)
         self.edges[-1] += 1e-6
         if momentum > 0:
-            self.acc_sum = torch.zeros(bins).cuda()
+            acc_sum = torch.zeros(bins)
+            self.register_buffer('acc_sum', acc_sum)
         self.use_sigmoid = use_sigmoid
         if not self.use_sigmoid:
             raise NotImplementedError
@@ -111,11 +113,13 @@ class GHMR(nn.Module):
         super(GHMR, self).__init__()
         self.mu = mu
         self.bins = bins
-        self.edges = torch.arange(bins + 1).float().cuda() / bins
+        edges = torch.arange(bins + 1).float() / bins
+        self.register_buffer('edges', edges)
         self.edges[-1] = 1e3
         self.momentum = momentum
         if momentum > 0:
-            self.acc_sum = torch.zeros(bins).cuda()
+            acc_sum = torch.zeros(bins)
+            self.register_buffer('acc_sum', acc_sum)
         self.loss_weight = loss_weight
 
     # TODO: support reduction parameter
-- 
GitLab