Commit 864a942f authored by Chakraborty
parents 910b63f7 f473f516
@@ -217,41 +217,19 @@ class CustomTorchPolicy(TorchPolicy):
                    ent_coef, vf_coef,
                    obs, returns, actions, values, logp_actions_old, advs):
        if not self.config['pi_phase_mixed_precision']:
            loss, vf_loss = self._calc_pi_vf_loss(apply_grad, num_accumulate,
                                                  cliprange, vfcliprange, max_grad_norm,
                                                  ent_coef, vf_coef,
                                                  obs, returns, actions, values, logp_actions_old, advs)
            loss.backward()
            vf_loss.backward()
            if apply_grad:
                self.optimizer.step()
                self.optimizer.zero_grad()
                if not self.config['single_optimizer']:
                    self.value_optimizer.step()
                    self.value_optimizer.zero_grad()
        else:
            with autocast():
                loss, vf_loss = self._calc_pi_vf_loss(apply_grad, num_accumulate,
                                                      cliprange, vfcliprange, max_grad_norm,
                                                      ent_coef, vf_coef,
                                                      obs, returns, actions, values, logp_actions_old, advs)
            self.amp_scaler.scale(loss).backward(retain_graph=True)
            self.amp_scaler.scale(vf_loss).backward()
            if apply_grad:
                self.amp_scaler.step(self.optimizer)
                if not self.config['single_optimizer']:
                    self.amp_scaler.step(self.value_optimizer)
                self.amp_scaler.update()
                self.optimizer.zero_grad()
                if not self.config['single_optimizer']:
                    self.value_optimizer.zero_grad()
        loss, vf_loss = self._calc_pi_vf_loss(apply_grad, num_accumulate,
                                              cliprange, vfcliprange, max_grad_norm,
                                              ent_coef, vf_coef,
                                              obs, returns, actions, values, logp_actions_old, advs)
        loss.backward()
        vf_loss.backward()
        if apply_grad:
            self.optimizer.step()
            self.optimizer.zero_grad()
            if not self.config['single_optimizer']:
                self.value_optimizer.step()
                self.value_optimizer.zero_grad()
    def _calc_pi_vf_loss(self, apply_grad, num_accumulate,
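For reference, the mixed-precision branch removed in the hunk above follows the standard torch.cuda.amp recipe: run the forward pass under autocast, scale each loss through a GradScaler before backward, then let the scaler step the optimizer(s) and update its loss scale; the optimizer steps stay gated on apply_grad so gradients can accumulate over several minibatches. A minimal, self-contained sketch of that pattern (the model, data, and losses below are toy placeholders, not this repo's code):

# Sketch of the torch.cuda.amp pattern used by the removed
# pi_phase_mixed_precision branch; toy model and data, not this repo's code.
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

device = "cuda" if torch.cuda.is_available() else "cpu"
policy_head = nn.Linear(8, 4).to(device)
value_head = nn.Linear(8, 1).to(device)
optimizer = torch.optim.Adam(policy_head.parameters(), lr=5e-4)
value_optimizer = torch.optim.Adam(value_head.parameters(), lr=5e-4)
scaler = GradScaler(enabled=(device == "cuda"))   # analogous to self.amp_scaler

obs = torch.randn(32, 8, device=device)
returns = torch.randn(32, 1, device=device)

# Forward passes run under autocast so eligible ops execute in float16.
with autocast(enabled=(device == "cuda")):
    pi_loss = policy_head(obs).pow(2).mean()              # stand-in for the clipped PPO loss
    vf_loss = (value_head(obs) - returns).pow(2).mean()   # stand-in for the value loss

# Scale each loss before backward so small fp16 gradients do not underflow;
# retain_graph=True mirrors the original code, which backpropagates two losses.
scaler.scale(pi_loss).backward(retain_graph=True)
scaler.scale(vf_loss).backward()

# scaler.step() unscales the gradients and skips the update if any are inf/NaN;
# a single scaler.update() afterwards adjusts the loss scale for the next step.
# In the policy above this block is additionally gated on apply_grad, which
# allows gradient accumulation across minibatches before an optimizer step.
scaler.step(optimizer)
scaler.step(value_optimizer)
scaler.update()
optimizer.zero_grad()
value_optimizer.zero_grad()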
@@ -385,9 +363,10 @@ class CustomTorchPolicy(TorchPolicy):
            self.best_rew_tsteps = self.timesteps_total
        if self.timesteps_total > self.target_timesteps or (self.time_elapsed + self.buffer_time) > self.max_time:
            if self.best_weights is not None:
                self.set_model_weights(self.best_weights)
                return True
        if self.timesteps_total > 1_000_000: # Adding this hack because the maze reward deque is very high at the beginning
            if self.best_weights is not None:
                self.set_model_weights(self.best_weights)
                return True
        return False
......
@@ -45,10 +45,10 @@ procgen-ppo:
    no_done_at_end: False
    # Custom switches
    skips: 2
    n_pi: 16
    num_retunes: 15
    retune_epochs: 7
    skips: 0
    n_pi: 32
    num_retunes: 8
    retune_epochs: 6
    standardize_rewards: True
    aux_mbsize: 4
    aux_num_accumulates: 2
@@ -112,6 +112,11 @@ procgen-ppo:
    explore: True,
    exploration_config:
      type: "StochasticSampling"
    evaluation_config:
      exploration_config:
        type: SoftQ
        temperature: 0.5
    observation_filter: "NoFilter"
    synchronize_filters: True
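The evaluation_config block in this hunk gives evaluation rollouts their own exploration settings: in RLlib, evaluation_config is a partial override of the trainer config that applies only to evaluation workers, so training keeps StochasticSampling while evaluation samples through SoftQ with a temperature of 0.5. The same override written as an illustrative Python config dict (keys taken from the YAML above; not code from this repo):

# Illustrative only: the exploration override above as an RLlib-style config dict.
# Training workers keep StochasticSampling; evaluation workers inherit the base
# config and swap in temperature-scaled SoftQ sampling.
config = {
    "explore": True,
    "exploration_config": {"type": "StochasticSampling"},
    "evaluation_config": {
        "exploration_config": {"type": "SoftQ", "temperature": 0.5},
    },
    "observation_filter": "NoFilter",
    "synchronize_filters": True,
}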
......