Commit 925d66aa authored by Dipam Chakraborty

normalize adv by all rollouts

parent f48651b2
@@ -163,18 +163,14 @@ class CustomTorchPolicy(TorchPolicy):
         actions = samples['actions']
         returns = roll(mb_returns)
         advs = returns - values
+        normalized_advs = (advs - np.mean(advs)) / (np.std(advs) + 1e-8)
         ## Train multiple epochs
         optim_count = 0
         inds = np.arange(nbatch)
         for _ in range(noptepochs):
             np.random.shuffle(inds)
-            normalized_advs = returns - values
-            # Can do this because actual_batch_size is a multiple of mem_limited_batch_size
-            for start in range(0, nbatch, self.actual_batch_size):
-                end = start + self.actual_batch_size
-                mbinds = inds[start:end]
-                advs_batch = normalized_advs[mbinds].copy()
-                normalized_advs[mbinds] = (advs_batch - np.mean(advs_batch)) / (np.std(advs_batch) + 1e-8)
             for start in range(0, nbatch, nbatch_train):
                 end = start + nbatch_train
                 mbinds = inds[start:end]
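For context, the change moves advantage normalization out of the per-minibatch loop: advantages are now standardized once over the entire rollout buffer before the epoch loop, instead of being re-standardized inside each memory-limited minibatch. Below is a minimal, self-contained sketch contrasting the two schemes; it is not the repository's code, and the array size, the RNG seed, and the synthetic advantage data are illustrative assumptions (only the names advs, nbatch, nbatch_train, and inds mirror the diff above).

    # Sketch only: synthetic data, not the repository's code.
    import numpy as np

    rng = np.random.default_rng(0)
    advs = rng.normal(loc=0.5, scale=2.0, size=2048)  # fake advantages (returns - values)
    nbatch, nbatch_train = advs.size, 256

    # After this commit: normalize once over all rollouts, so every
    # minibatch sees advantages on the same scale.
    global_advs = (advs - advs.mean()) / (advs.std() + 1e-8)

    # Before this commit: re-normalize inside each minibatch, which
    # re-centers and re-scales using only that minibatch's statistics.
    inds = np.arange(nbatch)
    rng.shuffle(inds)
    per_batch_advs = advs.copy()
    for start in range(0, nbatch, nbatch_train):
        mbinds = inds[start:start + nbatch_train]
        b = per_batch_advs[mbinds]
        per_batch_advs[mbinds] = (b - b.mean()) / (b.std() + 1e-8)

Presumably the motivation is that normalizing over all rollouts keeps every minibatch on the same advantage scale across epochs, whereas per-minibatch statistics re-center each slice independently and become noisier as the minibatch shrinks.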