Commit 0cae17f4 authored by Dipam Chakraborty

init normed

parent 9b28f2ac
@@ -80,7 +80,7 @@ procgen-ppo:
 #      nlatents: 1024
       depths: [32, 64, 64]
       nlatents: 512
-      init_glorot: False
+      init_normed: True
       use_layernorm: True
   num_workers: 7
@@ -8,14 +8,14 @@ torch, nn = try_import_torch()
 class ResidualBlock(nn.Module):
-    def __init__(self, channels, init_glorot=False):
+    def __init__(self, channels, init_normed=False):
         super().__init__()
         self.conv0 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1)
         self.conv1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1)
-        if init_glorot:
-            nn.init.xavier_uniform_(self.conv0.weight)
+        if init_normed:
+            self.conv0.weight.data *= 1 / self.conv0.weight.norm(dim=(1, 2, 3), p=2, keepdim=True)
             nn.init.zeros_(self.conv0.bias)
-            nn.init.xavier_uniform_(self.conv1.weight)
+            self.conv1.weight.data *= 1 / self.conv1.weight.norm(dim=(1, 2, 3), p=2, keepdim=True)
             nn.init.zeros_(self.conv1.bias)
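For reference, a minimal standalone sketch of what the per-filter rescaling above does (illustrative only, not part of the commit): each output channel's fan-in weights are divided by their L2 norm, so every filter starts with norm exactly 1 while keeping its randomly initialized direction.

import torch
from torch import nn

# Default-initialized 3x3 conv, as in ResidualBlock.
conv = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)

with torch.no_grad():
    # Rescale each of the 32 filters to unit L2 norm over (in, kh, kw).
    conv.weight *= 1 / conv.weight.norm(dim=(1, 2, 3), p=2, keepdim=True)
    nn.init.zeros_(conv.bias)

print(conv.weight.norm(dim=(1, 2, 3), p=2))  # 32 values, all ~1.0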
@@ -29,15 +29,15 @@ class ResidualBlock(nn.Module):
 class ConvSequence(nn.Module):
-    def __init__(self, input_shape, out_channels, init_glorot=False):
+    def __init__(self, input_shape, out_channels, init_normed=False):
         super().__init__()
         self._input_shape = input_shape
         self._out_channels = out_channels
         self.conv = nn.Conv2d(in_channels=self._input_shape[0], out_channels=self._out_channels, kernel_size=3, padding=1)
-        self.res_block0 = ResidualBlock(self._out_channels, init_glorot)
-        self.res_block1 = ResidualBlock(self._out_channels, init_glorot)
-        if init_glorot:
-            nn.init.xavier_uniform_(self.conv.weight)
+        self.res_block0 = ResidualBlock(self._out_channels, init_normed)
+        self.res_block1 = ResidualBlock(self._out_channels, init_normed)
+        if init_normed:
+            self.conv.weight.data *= 1 / self.conv.weight.norm(dim=(1, 2, 3), p=2, keepdim=True)
             nn.init.zeros_(self.conv.bias)

     def forward(self, x, pool=True):
@@ -71,7 +71,7 @@ class ImpalaCNN(TorchModelV2, nn.Module):
         self.device = device
         depths = model_config['custom_options'].get('depths') or [16, 32, 32]
         nlatents = model_config['custom_options'].get('nlatents') or 256
-        init_glorot = model_config['custom_options'].get('init_glorot') or False
+        init_normed = model_config['custom_options'].get('init_normed') or False
         self.use_layernorm = model_config['custom_options'].get('use_layernorm') or True
         h, w, c = obs_space.shape
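A caveat on the lookup idiom in this hunk (standalone illustration, not from the repo): `dict.get(key) or default` falls back to the default for any falsy configured value, not just a missing key. For `init_normed` this is harmless since the fallback is False anyway, but `use_layernorm: False` or `nlatents: 0` in the config would be silently overridden.

custom_options = {'use_layernorm': False, 'nlatents': 0}

# `or` discards falsy values, not just missing keys.
use_layernorm = custom_options.get('use_layernorm') or True
nlatents = custom_options.get('nlatents') or 256

print(use_layernorm)  # True  -- the configured False is lost
print(nlatents)       # 256   -- the configured 0 is lost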
@@ -79,19 +79,23 @@ class ImpalaCNN(TorchModelV2, nn.Module):
         conv_seqs = []
         for out_channels in depths:
-            conv_seq = ConvSequence(shape, out_channels, init_glorot)
+            conv_seq = ConvSequence(shape, out_channels, init_normed)
             shape = conv_seq.get_output_shape()
             conv_seqs.append(conv_seq)
         self.conv_seqs = nn.ModuleList(conv_seqs)
         self.hidden_fc = nn.Linear(in_features=shape[0] * shape[1] * shape[2], out_features=nlatents)
-        if init_glorot:
-            nn.init.xavier_uniform_(self.hidden_fc.weight)
+        if init_normed:
+            self.hidden_fc.weight.data *= 1.4 / self.hidden_fc.weight.norm(dim=1, p=2, keepdim=True)
         nn.init.zeros_(self.hidden_fc.bias)
         self.pi_fc = nn.Linear(in_features=nlatents, out_features=num_outputs)
-        nn.init.orthogonal_(self.pi_fc.weight, gain=0.01)
-        nn.init.zeros_(self.pi_fc.bias)
         self.value_fc = nn.Linear(in_features=nlatents, out_features=1)
-        nn.init.orthogonal_(self.value_fc.weight, gain=1)
+        if init_normed:
+            self.pi_fc.weight.data *= 0.1 / self.pi_fc.weight.norm(dim=1, p=2, keepdim=True)
+            self.value_fc.weight.data *= 0.1 / self.value_fc.weight.norm(dim=1, p=2, keepdim=True)
+        else:
+            nn.init.orthogonal_(self.pi_fc.weight, gain=0.01)
+            nn.init.orthogonal_(self.value_fc.weight, gain=1)
+        nn.init.zeros_(self.pi_fc.bias)
         nn.init.zeros_(self.value_fc.bias)
         if self.use_layernorm:
             self.layernorm = nn.LayerNorm(nlatents)
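A minimal sketch of the fully connected head initialization above (sizes here are illustrative assumptions, not from the commit): each row of the weight matrix, i.e. each output unit's fan-in vector, is rescaled to a fixed L2 norm, 1.4 for the hidden layer and 0.1 for the policy and value heads. Unlike a gain-based orthogonal init, this pins the row norms exactly, independent of fan-in, and the small 0.1 scale keeps both heads near zero at the start of training, analogous to the gain=0.01 orthogonal init in the else branch.

import torch
from torch import nn

nlatents, num_outputs = 512, 15  # illustrative sizes, not from the config

hidden_fc = nn.Linear(nlatents, nlatents)
pi_fc = nn.Linear(nlatents, num_outputs)
value_fc = nn.Linear(nlatents, 1)

with torch.no_grad():
    # Per-row rescaling: every output unit's fan-in vector gets a fixed norm.
    hidden_fc.weight *= 1.4 / hidden_fc.weight.norm(dim=1, p=2, keepdim=True)
    pi_fc.weight *= 0.1 / pi_fc.weight.norm(dim=1, p=2, keepdim=True)
    value_fc.weight *= 0.1 / value_fc.weight.norm(dim=1, p=2, keepdim=True)

print(hidden_fc.weight.norm(dim=1, p=2)[:3])  # ~1.4 per row
print(pi_fc.weight.norm(dim=1, p=2)[:3])      # ~0.1 per row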