Skip to content
Snippets Groups Projects
Commit c899cdf1 authored by Kai Chen's avatar Kai Chen
Browse files

bug fix for freezing parameters

parent 6fe5ccde
No related branches found
No related tags found
No related merge requests found
...@@ -421,12 +421,14 @@ class ResNet(nn.Module): ...@@ -421,12 +421,14 @@ class ResNet(nn.Module):
def _freeze_stages(self): def _freeze_stages(self):
if self.frozen_stages >= 0: if self.frozen_stages >= 0:
self.norm1.eval()
for m in [self.conv1, self.norm1]: for m in [self.conv1, self.norm1]:
for param in m.parameters(): for param in m.parameters():
param.requires_grad = False param.requires_grad = False
for i in range(1, self.frozen_stages + 1): for i in range(1, self.frozen_stages + 1):
m = getattr(self, 'layer{}'.format(i)) m = getattr(self, 'layer{}'.format(i))
m.eval()
for param in m.parameters(): for param in m.parameters():
param.requires_grad = False param.requires_grad = False
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment