diff --git a/mmdet/models/anchor_heads/ssd_head.py b/mmdet/models/anchor_heads/ssd_head.py
index 600dd4a13b6473af50fd761260dec92cb555c4e6..056dfa53f25ef7a954be7c553de7329be2c711a2 100644
--- a/mmdet/models/anchor_heads/ssd_head.py
+++ b/mmdet/models/anchor_heads/ssd_head.py
@@ -143,8 +143,10 @@ class SSDHead(AnchorHead):
         featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
         assert len(featmap_sizes) == len(self.anchor_generators)
 
+        device = cls_scores[0].device
+
         anchor_list, valid_flag_list = self.get_anchors(
-            featmap_sizes, img_metas)
+            featmap_sizes, img_metas, device=device)
         cls_reg_targets = anchor_target(
             anchor_list,
             valid_flag_list,
diff --git a/mmdet/models/losses/ghm_loss.py b/mmdet/models/losses/ghm_loss.py
index 95656a21efb195529ee8c8f747b06e588555d423..e62b9904f56a15fc8b2ee045613afd7254850685 100644
--- a/mmdet/models/losses/ghm_loss.py
+++ b/mmdet/models/losses/ghm_loss.py
@@ -35,10 +35,12 @@ class GHMC(nn.Module):
         super(GHMC, self).__init__()
         self.bins = bins
         self.momentum = momentum
-        self.edges = torch.arange(bins + 1).float().cuda() / bins
+        edges = torch.arange(bins + 1).float() / bins
+        self.register_buffer('edges', edges)
         self.edges[-1] += 1e-6
         if momentum > 0:
-            self.acc_sum = torch.zeros(bins).cuda()
+            acc_sum = torch.zeros(bins)
+            self.register_buffer('acc_sum', acc_sum)
         self.use_sigmoid = use_sigmoid
         if not self.use_sigmoid:
             raise NotImplementedError
@@ -111,11 +113,13 @@ class GHMR(nn.Module):
         super(GHMR, self).__init__()
         self.mu = mu
         self.bins = bins
-        self.edges = torch.arange(bins + 1).float().cuda() / bins
+        edges = torch.arange(bins + 1).float() / bins
+        self.register_buffer('edges', edges)
         self.edges[-1] = 1e3
         self.momentum = momentum
         if momentum > 0:
-            self.acc_sum = torch.zeros(bins).cuda()
+            acc_sum = torch.zeros(bins)
+            self.register_buffer('acc_sum', acc_sum)
         self.loss_weight = loss_weight
 
     # TODO: support reduction parameter