From d9172919d00695f0dddb0ab10c89c3b7a8e26427 Mon Sep 17 00:00:00 2001 From: Jon Crall <erotemic@gmail.com> Date: Wed, 23 Oct 2019 11:54:26 -0400 Subject: [PATCH] Fix SSD Head and GHM Loss on CPU (#1578) * Fix GHM loss on CPU * Fix SSD head on CPU --- mmdet/models/anchor_heads/ssd_head.py | 4 +++- mmdet/models/losses/ghm_loss.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/mmdet/models/anchor_heads/ssd_head.py b/mmdet/models/anchor_heads/ssd_head.py index 600dd4a..056dfa5 100644 --- a/mmdet/models/anchor_heads/ssd_head.py +++ b/mmdet/models/anchor_heads/ssd_head.py @@ -143,8 +143,10 @@ class SSDHead(AnchorHead): featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] assert len(featmap_sizes) == len(self.anchor_generators) + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( - featmap_sizes, img_metas) + featmap_sizes, img_metas, device=device) cls_reg_targets = anchor_target( anchor_list, valid_flag_list, diff --git a/mmdet/models/losses/ghm_loss.py b/mmdet/models/losses/ghm_loss.py index 95656a2..e62b990 100644 --- a/mmdet/models/losses/ghm_loss.py +++ b/mmdet/models/losses/ghm_loss.py @@ -35,10 +35,12 @@ class GHMC(nn.Module): super(GHMC, self).__init__() self.bins = bins self.momentum = momentum - self.edges = torch.arange(bins + 1).float().cuda() / bins + edges = torch.arange(bins + 1).float() / bins + self.register_buffer('edges', edges) self.edges[-1] += 1e-6 if momentum > 0: - self.acc_sum = torch.zeros(bins).cuda() + acc_sum = torch.zeros(bins) + self.register_buffer('acc_sum', acc_sum) self.use_sigmoid = use_sigmoid if not self.use_sigmoid: raise NotImplementedError @@ -111,11 +113,13 @@ class GHMR(nn.Module): super(GHMR, self).__init__() self.mu = mu self.bins = bins - self.edges = torch.arange(bins + 1).float().cuda() / bins + edges = torch.arange(bins + 1).float() / bins + self.register_buffer('edges', edges) self.edges[-1] = 1e3 self.momentum = momentum if momentum > 0: - self.acc_sum = torch.zeros(bins).cuda() + acc_sum = torch.zeros(bins) + self.register_buffer('acc_sum', acc_sum) self.loss_weight = loss_weight # TODO: support reduction parameter -- GitLab