Commit 1c6b58be authored by Dipam Chakraborty

ppg single optim

parent c1c85730
@@ -51,7 +51,10 @@ class CustomTorchPolicy(TorchPolicy):
         network_params = set(self.model.parameters())
         aux_optim_params = list(network_params - value_params)
         ppo_optim_params = list(network_params - aux_params - value_params)
-        self.optimizer = torch.optim.Adam(ppo_optim_params, lr=self.config['lr'])
+        if not self.config['single_optimizer']:
+            self.optimizer = torch.optim.Adam(ppo_optim_params, lr=self.config['lr'])
+        else:
+            self.optimizer = torch.optim.Adam(network_params, lr=self.config['lr'])
         self.aux_optimizer = torch.optim.Adam(aux_optim_params, lr=self.config['aux_lr'])
         self.value_optimizer = torch.optim.Adam(value_params, lr=self.config['value_lr'])
         self.max_reward = self.config['env_config']['return_max']
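
Note: this hunk is the core of the change. The policy partitions one network's parameters into disjoint groups by set difference and, with `single_optimizer` off, gives each group its own Adam and learning rate (`lr`, `aux_lr`, `value_lr`); with it on, a single Adam over the whole `network_params` set is used in every phase. A minimal sketch of the partition on a hypothetical two-head model (layer names, sizes, and learning rates below are illustrative stand-ins, not this repo's actual architecture):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the policy network: shared trunk, policy head,
# value head, and auxiliary value head. Only the set arithmetic mirrors the diff.
model = nn.ModuleDict({
    "trunk": nn.Linear(64, 64),
    "pi": nn.Linear(64, 15),
    "vf": nn.Linear(64, 1),
    "aux_vf": nn.Linear(64, 1),
})

# nn.Parameter hashes by identity, so set difference yields disjoint groups.
value_params = set(model["vf"].parameters())
aux_params = set(model["aux_vf"].parameters())
network_params = set(model.parameters())
aux_optim_params = list(network_params - value_params)               # trunk + pi + aux_vf
ppo_optim_params = list(network_params - aux_params - value_params)  # trunk + pi

single_optimizer = True
if not single_optimizer:
    # Three optimizers: each head keeps its own Adam state and learning rate.
    optimizer = torch.optim.Adam(ppo_optim_params, lr=5e-4)
    aux_optimizer = torch.optim.Adam(aux_optim_params, lr=5e-4)
    value_optimizer = torch.optim.Adam(list(value_params), lr=1e-3)
else:
    # One Adam, one set of moment estimates, for every parameter.
    optimizer = torch.optim.Adam(list(network_params), lr=5e-4)
```

As far as these hunks show, `aux_lr` and `value_lr` become inert in single-optimizer mode: the per-head optimizers are still constructed but never stepped.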
@@ -201,10 +204,7 @@ class CustomTorchPolicy(TorchPolicy):
         ## Distill with aux head
         should_retune = self.retune_selector.update(unroll(obs, ts), mb_returns)
         if should_retune:
-            import time
-            tnow = time.perf_counter()
             self.aux_train()
-            print("Aux Train %fs" % (time.perf_counter()-tnow))
         self.update_gamma(samples)
         self.update_lr()
@@ -243,9 +243,10 @@ class CustomTorchPolicy(TorchPolicy):
         vf_loss.backward()
         if apply_grad:
             self.optimizer.step()
-            self.value_optimizer.step()
             self.optimizer.zero_grad()
-            self.value_optimizer.zero_grad()
+            if not self.config['single_optimizer']:
+                self.value_optimizer.step()
+                self.value_optimizer.zero_grad()

     def aux_train(self):
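
The value head's separate `step()`/`zero_grad()` is now gated: in single-optimizer mode, `vf_loss.backward()` still deposits gradients on the value parameters, and the one `self.optimizer.step()` applies them, because that Adam was built over `network_params` rather than `ppo_optim_params`. A schematic restatement (free-standing function; `config` and the optimizers stand in for the policy's attributes):

```python
def apply_ppo_gradients(config, optimizer, value_optimizer, loss, vf_loss, apply_grad):
    """Schematic restatement of the PPO-phase stepping logic in the hunk above."""
    loss.backward()      # policy (and shared-trunk) gradients
    vf_loss.backward()   # value-head gradients accumulate either way
    if apply_grad:
        # Single-optimizer mode: this step() already covers the value head.
        optimizer.step()
        optimizer.zero_grad()
        if not config['single_optimizer']:
            # Multi-optimizer mode: the value head steps on its own schedule.
            value_optimizer.step()
            value_optimizer.zero_grad()
```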
@@ -282,8 +283,11 @@ class CustomTorchPolicy(TorchPolicy):
                 loss, vf_loss = self._aux_calc_loss(obs_in, target_vf, target_pi)
                 loss.backward()
                 vf_loss.backward()
-                self.aux_optimizer.step()
-                self.value_optimizer.step()
+                if not self.config['single_optimizer']:
+                    self.aux_optimizer.step()
+                    self.value_optimizer.step()
+                else:
+                    self.optimizer.step()
             else:
                 with autocast():
@@ -292,13 +296,19 @@ class CustomTorchPolicy(TorchPolicy):
                 self.amp_scaler.scale(loss).backward(retain_graph=True)
                 self.amp_scaler.scale(vf_loss).backward()
-                self.amp_scaler.step(self.aux_optimizer)
-                self.amp_scaler.step(self.value_optimizer)
+                if not self.config['single_optimizer']:
+                    self.amp_scaler.step(self.aux_optimizer)
+                    self.amp_scaler.step(self.value_optimizer)
+                else:
+                    self.amp_scaler.step(self.optimizer)
                 self.amp_scaler.update()
-            self.aux_optimizer.zero_grad()
-            self.value_optimizer.zero_grad()
+            if not self.config['single_optimizer']:
+                self.aux_optimizer.zero_grad()
+                self.value_optimizer.zero_grad()
+            else:
+                self.optimizer.zero_grad()

     def _aux_calc_loss(self, obs_in, target_vf, target_pi):
         vpred, pi_logits = self.model.vf_pi(obs_in, ret_numpy=False, no_grad=False, to_torch=False)
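
The mixed-precision aux phase follows the standard `torch.cuda.amp` multi-optimizer recipe: scale each loss, step every optimizer through the shared scaler, then call `scaler.update()` exactly once per iteration; the flag only changes whether one or two optimizers get stepped. A self-contained toy version of that recipe (all modules, losses, and sizes are hypothetical; it degrades to plain FP32 on CPU):

```python
import torch
from torch.cuda.amp import GradScaler, autocast

use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"

trunk = torch.nn.Linear(8, 8).to(device)
head_a = torch.nn.Linear(8, 1).to(device)   # stands in for the aux head
head_b = torch.nn.Linear(8, 1).to(device)   # stands in for the value head

opt_a = torch.optim.Adam(list(trunk.parameters()) + list(head_a.parameters()))
opt_b = torch.optim.Adam(head_b.parameters())
scaler = GradScaler(enabled=use_amp)        # plays the role of self.amp_scaler

x = torch.randn(4, 8, device=device)
with autocast(enabled=use_amp):
    h = trunk(x)
    loss_a = head_a(h).pow(2).mean()        # placeholder losses
    loss_b = head_b(h).pow(2).mean()

# Both losses backprop through the shared trunk, hence retain_graph on the first.
scaler.scale(loss_a).backward(retain_graph=True)
scaler.scale(loss_b).backward()
scaler.step(opt_a)   # in single-optimizer mode this collapses to one step()
scaler.step(opt_b)
scaler.update()      # exactly once per iteration, after all step() calls
opt_a.zero_grad()
opt_b.zero_grad()
```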
@@ -95,6 +95,7 @@ DEFAULT_CONFIG = with_common_config({
     "value_lr": 1e-3,
     "same_lr_everywhere": False,
     "aux_phase_mixed_precision": False,
+    "single_optimizer": False,
})
# __sphinx_doc_end__
# yapf: enable
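
The flag defaults to `False` in `DEFAULT_CONFIG`, so existing experiments keep the three-optimizer behaviour unless a config turns it on, as the procgen-ppo YAML below does. A schematic Python-side override (the abbreviated dict stands in for the full config above):

```python
from copy import deepcopy

DEFAULT_CONFIG = {  # abbreviated stand-in for the dict defined above
    "value_lr": 1e-3,
    "same_lr_everywhere": False,
    "aux_phase_mixed_precision": False,
    "single_optimizer": False,
}

config = deepcopy(DEFAULT_CONFIG)
config["single_optimizer"] = True   # mirrors the experiment YAML below
```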
@@ -60,6 +60,7 @@ procgen-ppo:
     value_lr: 1.0e-3
     same_lr_everywhere: False
     aux_phase_mixed_precision: True
+    single_optimizer: True
     adaptive_gamma: False
     final_lr: 5.0e-5