Commit 2ec7664b authored by Chakraborty
parents 18adb202 c4d01795
@@ -301,13 +301,17 @@ class CustomTorchPolicy(TorchPolicy):
         # Tune vf and pi heads to older predictions with (augmented?) observations
         for ep in range(retune_epochs):
+            counter = 0
             for slices in self.retune_selector.make_minibatches(replay_pi):
-                self.tune_policy(slices[0], self.to_tensor(slices[1]), self.to_tensor(slices[2]))
+                counter += 1
+                apply_grad = (counter % 2) == 0
+                self.tune_policy(slices[0], self.to_tensor(slices[1]), self.to_tensor(slices[2]),
+                                 apply_grad, num_accumulate=2)
 
         self.retunes_completed += 1
         self.retune_selector.retune_done()
 
-    def tune_policy(self, obs, target_vf, target_pi):
+    def tune_policy(self, obs, target_vf, target_pi, apply_grad, num_accumulate):
         if self.config['augment_buffer']:
             obs_aug = np.empty(obs.shape, obs.dtype)
             aug_idx = np.random.randint(self.config['augment_randint_num'], size=len(obs))
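
The change above turns the retune loop into a two-minibatch gradient-accumulation loop: gradients are accumulated on every minibatch, but the optimizer only steps when apply_grad is true, i.e. on every second call. A minimal sketch of that pattern in isolation, where model, optimizer, minibatches and compute_loss are stand-in names rather than this repository's API:

    import torch

    def compute_loss(model, batch):
        # Stand-in loss: mean-squared error against a stored target, loosely
        # analogous to the vf/pi retuning targets used above.
        obs, target = batch
        return 0.5 * torch.mean((model(obs) - target) ** 2)

    def retune_with_accumulation(model, optimizer, minibatches, num_accumulate=2):
        # Accumulate gradients over num_accumulate minibatches, then step once.
        optimizer.zero_grad()
        for counter, batch in enumerate(minibatches, start=1):
            loss = compute_loss(model, batch) / num_accumulate  # pre-scale so grads average
            loss.backward()
            if counter % num_accumulate == 0:  # mirrors apply_grad = (counter % 2) == 0
                optimizer.step()
                optimizer.zero_grad()

The effective batch size per optimizer step doubles while the memory cost of each backward pass stays at one minibatch, which is presumably the motivation for the change.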
@@ -319,38 +323,42 @@ class CustomTorchPolicy(TorchPolicy):
         obs_in = self.to_tensor(obs)
         if not self.config['aux_phase_mixed_precision']:
-            loss, vf_loss = self._aux_calc_loss(obs_in, target_vf, target_pi)
+            loss, vf_loss = self._aux_calc_loss(obs_in, target_vf, target_pi, num_accumulate)
             loss.backward()
             vf_loss.backward()
 
-            if not self.config['single_optimizer']:
-                self.aux_optimizer.step()
-                self.value_optimizer.step()
-            else:
-                self.optimizer.step()
+            if apply_grad:
+                if not self.config['single_optimizer']:
+                    self.aux_optimizer.step()
+                    self.value_optimizer.step()
+                else:
+                    self.optimizer.step()
         else:
             with autocast():
-                loss, vf_loss = self._aux_calc_loss(obs_in, target_vf, target_pi)
+                loss, vf_loss = self._aux_calc_loss(obs_in, target_vf, target_pi, num_accumulate)
             self.amp_scaler.scale(loss).backward(retain_graph=True)
             self.amp_scaler.scale(vf_loss).backward()
 
-            if not self.config['single_optimizer']:
-                self.amp_scaler.step(self.aux_optimizer)
-                self.amp_scaler.step(self.value_optimizer)
-            else:
-                self.amp_scaler.step(self.optimizer)
-            self.amp_scaler.update()
+            if apply_grad:
+                if not self.config['single_optimizer']:
+                    self.amp_scaler.step(self.aux_optimizer)
+                    self.amp_scaler.step(self.value_optimizer)
+                else:
+                    self.amp_scaler.step(self.optimizer)
+                self.amp_scaler.update()
 
-        if not self.config['single_optimizer']:
-            self.aux_optimizer.zero_grad()
-            self.value_optimizer.zero_grad()
-        else:
-            self.optimizer.zero_grad()
+        if apply_grad:
+            if not self.config['single_optimizer']:
+                self.aux_optimizer.zero_grad()
+                self.value_optimizer.zero_grad()
+            else:
+                self.optimizer.zero_grad()
 
-    def _aux_calc_loss(self, obs_in, target_vf, target_pi):
+    def _aux_calc_loss(self, obs_in, target_vf, target_pi, num_accumulate):
         vpred, pi_logits = self.model.vf_pi(obs_in, ret_numpy=False, no_grad=False, to_torch=False)
         aux_vpred = self.model.aux_value_function()
         aux_loss = .5 * torch.mean(torch.pow(aux_vpred - target_vf, 2))
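
In the mixed-precision branch this is the usual torch.cuda.amp pairing of autocast with a gradient scaler, except that the scaler step, the scale update and zero_grad are all gated on apply_grad so that scaled gradients accumulate across calls. A rough single-optimizer sketch of that interaction (compute_loss is again a placeholder, and the aux/value optimizer split from the real code is omitted):

    import torch
    from torch.cuda.amp import autocast, GradScaler

    scaler = GradScaler()

    def amp_accumulation_step(model, optimizer, batch, counter, num_accumulate=2):
        with autocast():
            loss = compute_loss(model, batch) / num_accumulate  # pre-scaled for accumulation
        scaler.scale(loss).backward()      # scaled grads accumulate in the .grad buffers
        if counter % num_accumulate == 0:  # the apply_grad condition
            scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
            scaler.update()                # adjust the loss scale for the next iteration
            optimizer.zero_grad()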
@@ -362,6 +370,9 @@ class CustomTorchPolicy(TorchPolicy):
         loss = pi_loss + aux_loss
         vf_loss = .5 * torch.mean(torch.pow(vpred - target_vf, 2))
 
+        loss = loss / num_accumulate
+        vf_loss = vf_loss / num_accumulate
+
         return loss, vf_loss
 
     def best_reward_model_select(self, samples):
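
Dividing loss and vf_loss by num_accumulate before backward() makes the accumulated gradient an average over the accumulated minibatches rather than a sum, so the effective step size is unchanged when apply_grad finally fires. A toy check of that equivalence with num_accumulate = 2 (unrelated to the repository's model):

    import torch

    w = torch.ones(3, requires_grad=True)
    batches = [torch.full((3,), 1.0), torch.full((3,), 3.0)]

    # Accumulating grads of (loss / 2) over two batches...
    for x in batches:
        ((w * x).sum() / 2).backward()
    accumulated = w.grad.clone()

    # ...matches the grad of the mean of the two losses computed in one pass.
    w.grad = None
    (((w * batches[0]).sum() + (w * batches[1]).sum()) / 2).backward()
    assert torch.allclose(accumulated, w.grad)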
@@ -45,9 +45,9 @@ procgen-ppo:
     no_done_at_end: False
 
     # Custom switches
-    skips: 2
-    n_pi: 14
-    num_retunes: 12
+    skips: 6
+    n_pi: 10
+    num_retunes: 16
     retune_epochs: 6
     standardize_rewards: True
     aux_mbsize: 4
@@ -65,7 +65,7 @@ procgen-ppo:
     pi_phase_mixed_precision: True
 
     adaptive_gamma: False
-    final_lr: 5.0e-5
+    final_lr: 1.0e-4
     lr_schedule: 'linear'
     final_entropy_coeff: 0.002
     entropy_schedule: False
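
The final_lr change only has an effect because lr_schedule is 'linear': the learning rate is annealed from its starting value toward final_lr over the course of training. A hedged sketch of what such a schedule typically computes; the starting value of 5.0e-4 here is an illustrative placeholder, not a value taken from this config:

    def linear_lr(step, total_steps, initial_lr=5.0e-4, final_lr=1.0e-4):
        # Linearly interpolate from initial_lr at step 0 to final_lr at total_steps.
        frac = min(step / total_steps, 1.0)
        return initial_lr + frac * (final_lr - initial_lr)

    # Halfway through training: 5.0e-4 anneals to 3.0e-4.
    assert abs(linear_lr(500, 1000) - 3.0e-4) < 1e-12

Raising final_lr from 5.0e-5 to 1.0e-4 therefore keeps the late-training learning rate higher than before.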