Skip to content
Snippets Groups Projects
Commit 6668bf03 authored by Zhuliang Yao's avatar Zhuliang Yao Committed by Kai Chen
Browse files

fix a scale bug according to Non-Local papaer (#1528)

parent 0527e210
No related branches found
No related tags found
No related merge requests found
...@@ -76,7 +76,7 @@ class NonLocal2D(nn.Module): ...@@ -76,7 +76,7 @@ class NonLocal2D(nn.Module):
pairwise_weight = torch.matmul(theta_x, phi_x) pairwise_weight = torch.matmul(theta_x, phi_x)
if self.use_scale: if self.use_scale:
# theta_x.shape[-1] is `self.inter_channels` # theta_x.shape[-1] is `self.inter_channels`
pairwise_weight /= theta_x.shape[-1]**-0.5 pairwise_weight /= theta_x.shape[-1]**0.5
pairwise_weight = pairwise_weight.softmax(dim=-1) pairwise_weight = pairwise_weight.softmax(dim=-1)
return pairwise_weight return pairwise_weight
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment