Commit c4a8cd1a authored by Dipam Chakraborty

ppo expt changes

parent d28d29d6
@@ -171,7 +171,7 @@ class CustomTorchPolicy(TorchPolicy):
         max_grad_norm = self.config['grad_clip']
         ent_coef, vf_coef = self.ent_coef, self.config['vf_loss_coeff']
-        vf_coef *= self.timesteps_total / self.target_timesteps
+        vf_coef_now = vf_coef * self.timesteps_total / self.target_timesteps
         neglogpacs = -samples['action_logp'] ## np.isclose seems to be True always, otherwise compute again if needed
         noptepochs = self.config['num_sgd_iter']
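Note on the change above: the renamed vf_coef_now keeps the configured vf_loss_coeff intact and applies a linear ramp based on training progress instead of mutating vf_coef in place. A minimal sketch of that schedule, assuming timesteps_total is the number of environment steps collected so far and target_timesteps is the total step budget (the cap at 1.0 is an added safeguard, not part of the diff):

def vf_coef_schedule(vf_coef, timesteps_total, target_timesteps):
    # Linearly ramp the value-loss coefficient from 0 up to vf_coef over training.
    frac = min(timesteps_total / target_timesteps, 1.0)  # the min() is an assumption, not in the original
    return vf_coef * frac

# Example: with vf_loss_coeff = 0.5, the coefficient is 0.25 halfway through the budget.
assert vf_coef_schedule(0.5, 4_000_000, 8_000_000) == 0.25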
@@ -197,7 +197,7 @@ class CustomTorchPolicy(TorchPolicy):
                 optim_count += 1
                 apply_grad = (optim_count % self.accumulate_train_batches) == 0
                 self._batch_train(apply_grad, self.accumulate_train_batches,
-                                  lrnow, cliprange, vfcliprange, max_grad_norm, ent_coef, vf_coef, *slices)
+                                  lrnow, cliprange, vfcliprange, max_grad_norm, ent_coef, vf_coef_now, *slices)
         ## Distill with augmentation
         should_retune = self.retune_selector.update(obs, self.exp_replay)
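For context, apply_grad together with the loss / num_accumulate scaling in the next hunk implements gradient accumulation: each minibatch contributes a scaled gradient, and the optimizer only steps every accumulate_train_batches calls. A self-contained sketch of that pattern; the compute_loss helper is illustrative, not the repository's API:

import torch

def train_with_accumulation(model, optimizer, minibatches, num_accumulate, max_grad_norm, compute_loss):
    # Accumulate gradients over num_accumulate minibatches before each optimizer step.
    optimizer.zero_grad()
    for i, batch in enumerate(minibatches, start=1):
        loss = compute_loss(model, batch) / num_accumulate  # scale so the accumulated sum matches a full-batch loss
        loss.backward()
        if i % num_accumulate == 0:  # plays the role of apply_grad
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()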
@@ -241,18 +241,20 @@ class CustomTorchPolicy(TorchPolicy):
         loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef
         loss = loss / num_accumulate
-        latent = self.model._detached_latent
-        det_value = self.model.value_fc(latent).squeeze(1)
-        det_value_loss = .5 * torch.pow((det_value - returns), 2).mean()
         # Weird hack to remove vf_loss grad and only keep det_value_loss grad
         if self.model.value_fc.weight.grad is not None:
             self.value_fc_old_grads = [self.model.value_fc.weight.grad.clone(), self.model.value_fc.bias.grad.clone()]
         else:
             self.value_fc_old_grads = None
         loss.backward()
         self.model.value_fc.zero_grad()
+        latent = self.model._detached_latent
+        det_value = self.model.value_fc(latent).squeeze(1)
+        det_value_loss = .5 * torch.pow((det_value - returns), 2).mean()
         det_value_loss.backward()
         if self.value_fc_old_grads is not None:
             self.model.value_fc.weight.grad += self.value_fc_old_grads[0]
             self.model.value_fc.bias.grad += self.value_fc_old_grads[1]
...
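The "weird hack" above trains value_fc only on the detached-latent value loss while the rest of the network receives the full PPO gradient: any gradients already accumulated on value_fc are saved, the full loss is backpropagated, the value head's gradients are wiped, the detached value loss is backpropagated (it cannot reach the trunk because the latent is detached), and the saved gradients are added back so accumulation across minibatches is preserved. A standalone sketch of the same trick on a toy network; the module names below are illustrative, not the repository's:

import torch
import torch.nn as nn

torch.manual_seed(0)
body = nn.Linear(8, 16)          # shared trunk (stands in for the policy's feature extractor)
policy_head = nn.Linear(16, 4)   # trained on the full loss
value_fc = nn.Linear(16, 1)      # should only learn from the detached value loss

obs, returns = torch.randn(32, 8), torch.randn(32)
latent = body(obs)
full_loss = policy_head(latent).pow(2).mean() \
            + 0.5 * (value_fc(latent).squeeze(1) - returns).pow(2).mean()

# 1. Save any gradients already accumulated on the value head.
old_grads = None
if value_fc.weight.grad is not None:
    old_grads = [value_fc.weight.grad.clone(), value_fc.bias.grad.clone()]

# 2. Backprop the full loss; this populates grads everywhere, including value_fc.
full_loss.backward()

# 3. Wipe the value head's grads, then backprop the detached value loss.
#    Detaching the latent stops this second backward pass at value_fc.
value_fc.zero_grad()
det_value = value_fc(latent.detach()).squeeze(1)
det_value_loss = 0.5 * (det_value - returns).pow(2).mean()
det_value_loss.backward()

# 4. Add the saved grads back so earlier accumulation is not lost.
if old_grads is not None:
    value_fc.weight.grad += old_grads[0]
    value_fc.bias.grad += old_grads[1]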
@@ -33,7 +33,7 @@ procgen-ppo:
     lambda: 0.95
     lr: 5.0e-4
     # Number of SGD iterations in each outer loop
-    num_sgd_iter: 3
+    num_sgd_iter: 1
     vf_loss_coeff: 0.5
     entropy_coeff: 0.01
     clip_param: 0.2
@@ -46,10 +46,10 @@ procgen-ppo:
     no_done_at_end: False
     # Custom switches
-    retune_skips: 350000
+    retune_skips: 50000
-    retune_replay_size: 200000
+    retune_replay_size: 500000
     num_retunes: 13
-    retune_epochs: 3
+    retune_epochs: 6
     standardize_rewards: True
     scale_reward: 1.0
     return_reset: False
...
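The custom switches drive the distillation-with-augmentation ("retune") phase referenced by self.retune_selector.update(obs, self.exp_replay) in the policy. A plausible reading of the knobs: skip the first retune_skips timesteps, buffer roughly retune_replay_size observations, run retune_epochs of distillation whenever the buffer fills, and do this at most num_retunes times. The class below is a hypothetical trigger written only to illustrate how those switches could interact; it is not the repository's RetuneSelector:

class RetuneTrigger:
    # Hypothetical scheduler; field names mirror the config switches above.
    def __init__(self, retune_skips=50_000, retune_replay_size=500_000, num_retunes=13):
        self.retune_skips = retune_skips
        self.retune_replay_size = retune_replay_size
        self.retunes_left = num_retunes
        self.timesteps_seen = 0
        self.buffered = 0

    def update(self, batch_size):
        # Return True when enough fresh experience is buffered for another retune pass.
        self.timesteps_seen += batch_size
        if self.timesteps_seen < self.retune_skips or self.retunes_left == 0:
            return False
        self.buffered += batch_size
        if self.buffered >= self.retune_replay_size:
            self.buffered = 0
            self.retunes_left -= 1
            return True
        return False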