Commit bbe386a5 authored by Chakraborty
parents 8f996f60 c63cee77
@@ -228,6 +228,48 @@ class CustomTorchPolicy(TorchPolicy):
                      ent_coef, vf_coef,
                      obs, returns, actions, values, logp_actions_old, advs):
+        if not self.config['pi_phase_mixed_precision']:
+            loss, vf_loss = self._calc_pi_vf_loss(apply_grad, num_accumulate,
+                                                  cliprange, vfcliprange, max_grad_norm,
+                                                  ent_coef, vf_coef,
+                                                  obs, returns, actions, values, logp_actions_old, advs)
+            loss.backward()
+            vf_loss.backward()
+            if apply_grad:
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+                if not self.config['single_optimizer']:
+                    self.value_optimizer.step()
+                    self.value_optimizer.zero_grad()
+        else:
+            with autocast():
+                loss, vf_loss = self._calc_pi_vf_loss(apply_grad, num_accumulate,
+                                                      cliprange, vfcliprange, max_grad_norm,
+                                                      ent_coef, vf_coef,
+                                                      obs, returns, actions, values, logp_actions_old, advs)
+            self.amp_scaler.scale(loss).backward(retain_graph=True)
+            self.amp_scaler.scale(vf_loss).backward()
+            if apply_grad:
+                self.amp_scaler.step(self.optimizer)
+                if not self.config['single_optimizer']:
+                    self.amp_scaler.step(self.value_optimizer)
+                self.amp_scaler.update()
+                self.optimizer.zero_grad()
+                if not self.config['single_optimizer']:
+                    self.value_optimizer.zero_grad()
+
+    def _calc_pi_vf_loss(self, apply_grad, num_accumulate,
+                         cliprange, vfcliprange, max_grad_norm,
+                         ent_coef, vf_coef,
+                         obs, returns, actions, values, logp_actions_old, advs):
         vpred, pi_logits = self.model.vf_pi(obs, ret_numpy=False, no_grad=False, to_torch=False)
         pd = self.make_distr(pi_logits)
         logp_actions = pd.log_prob(actions[...,None]).squeeze(1)
@@ -244,17 +286,8 @@ class CustomTorchPolicy(TorchPolicy):
         loss = loss / num_accumulate
         vf_loss = vf_loss / num_accumulate
+        return loss, vf_loss
-        loss.backward()
-        vf_loss.backward()
-        if apply_grad:
-            self.optimizer.step()
-            self.optimizer.zero_grad()
-            if not self.config['single_optimizer']:
-                self.value_optimizer.step()
-                self.value_optimizer.zero_grad()

     def aux_train(self):
         nbatch_train = self.mem_limited_batch_size
         retune_epochs = self.config['retune_epochs']
@@ -294,6 +327,7 @@ class CustomTorchPolicy(TorchPolicy):
                         self.value_optimizer.step()
                     else:
                         self.optimizer.step()
             else:
                 with autocast():
...
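The new `pi_phase_mixed_precision` branch above follows the usual `torch.cuda.amp` recipe: run the forward pass and loss computation under `autocast()`, scale both losses through a shared `GradScaler` before calling `backward()` (with `retain_graph=True`, since the two losses share one forward graph), then `step()`/`update()` the scaler only on the accumulation steps where gradients are applied. Below is a minimal, self-contained sketch of that pattern; `PolicyValueNet`, `train_step`, and the random tensors are illustrative stand-ins, not names from this repository.

```python
# Sketch of the mixed-precision policy-phase update pattern used in the diff
# above. Assumes a CUDA device; PolicyValueNet and train_step are hypothetical
# stand-ins for the policy's model and optimizer plumbing.
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler


class PolicyValueNet(nn.Module):
    """Tiny policy/value network standing in for the IMPALA model."""

    def __init__(self, obs_dim=64, n_actions=15):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, 128), nn.ReLU())
        self.pi_head = nn.Linear(128, n_actions)
        self.vf_head = nn.Linear(128, 1)

    def forward(self, obs):
        h = self.body(obs)
        return self.pi_head(h), self.vf_head(h).squeeze(-1)


def train_step(model, optimizer, scaler, obs, actions, advs, returns,
               apply_grad=True, num_accumulate=1):
    # Forward pass and losses run under autocast so eligible ops use fp16.
    with autocast():
        logits, vpred = model(obs)
        logp = torch.distributions.Categorical(logits=logits).log_prob(actions)
        pi_loss = -(logp * advs).mean() / num_accumulate
        vf_loss = 0.5 * ((vpred - returns) ** 2).mean() / num_accumulate

    # Backward passes happen outside autocast; losses are scaled to avoid
    # fp16 gradient underflow. retain_graph is needed because both losses
    # backprop through the same forward graph.
    scaler.scale(pi_loss).backward(retain_graph=True)
    scaler.scale(vf_loss).backward()

    if apply_grad:
        scaler.step(optimizer)   # unscales grads, skips the step on inf/nan
        scaler.update()          # adjusts the loss scale for the next step
        optimizer.zero_grad()


if __name__ == "__main__" and torch.cuda.is_available():
    model = PolicyValueNet().cuda()
    opt = torch.optim.Adam(model.parameters(), lr=5e-4)
    scaler = GradScaler()
    obs = torch.randn(32, 64, device="cuda")
    actions = torch.randint(0, 15, (32,), device="cuda")
    advs = torch.randn(32, device="cuda")
    returns = torch.randn(32, device="cuda")
    train_step(model, opt, scaler, obs, actions, advs, returns)
```

Dividing both losses by `num_accumulate` before `backward()`, as the diff does in `_calc_pi_vf_loss`, keeps gradients accumulated over several minibatches on the same scale as a single large-batch update.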
@@ -97,6 +97,7 @@ DEFAULT_CONFIG = with_common_config({
     "aux_phase_mixed_precision": False,
     "single_optimizer": False,
     "max_time": 7200,
+    "pi_phase_mixed_precision": False,
 })
 # __sphinx_doc_end__
 # yapf: enable
...
@@ -45,9 +45,9 @@ procgen-ppo:
     no_done_at_end: False

     # Custom switches
-    skips: 2
-    n_pi: 16
-    num_retunes: 14
+    skips: 6
+    n_pi: 10
+    num_retunes: 16
     retune_epochs: 6
     standardize_rewards: True
     aux_mbsize: 4
@@ -62,15 +62,16 @@ procgen-ppo:
     aux_phase_mixed_precision: True
     single_optimizer: True
     max_time: 7200
+    pi_phase_mixed_precision: True
     adaptive_gamma: False
-    final_lr: 5.0e-5
+    final_lr: 1.0e-4
     lr_schedule: 'linear'
     final_entropy_coeff: 0.002
     entropy_schedule: False

     # Memory management, if batch size overflow, batch splitting is done to handle it
-    max_minibatch_size: 1000
+    max_minibatch_size: 500
     updates_per_batch: 8

     normalize_actions: False
@@ -87,8 +88,10 @@ procgen-ppo:
     model:
       custom_model: impala_torch_ppg
       custom_model_config:
-        depths: [32, 64, 64]
-        nlatents: 512
+        # depths: [32, 64, 64]
+        # nlatents: 512
+        depths: [64, 128, 128]
+        nlatents: 1024
         init_normed: True
         use_layernorm: False
         diff_framestack: True
...
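The last hunk roughly doubles the width of the convolutional trunk (`depths: [32, 64, 64]` → `[64, 128, 128]`) and the latent size (`512` → `1024`). The sketch below shows how an IMPALA-style trunk typically consumes these two options, one conv/pool/residual sequence per entry in `depths` followed by a linear projection to `nlatents`; it is an assumption about the shape of `impala_torch_ppg`, not its actual implementation, and it ignores the other options (`init_normed`, `use_layernorm`, `diff_framestack`).

```python
# Illustrative-only IMPALA-style trunk showing how `depths` and `nlatents`
# are typically interpreted; the real impala_torch_ppg model may differ.
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv0 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        h = self.conv0(torch.relu(x))
        h = self.conv1(torch.relu(h))
        return x + h


class ConvSequence(nn.Module):
    """Conv + 2x-downsampling max-pool + two residual blocks."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)
        self.res0 = ResidualBlock(out_channels)
        self.res1 = ResidualBlock(out_channels)

    def forward(self, x):
        return self.res1(self.res0(self.pool(self.conv(x))))


class ImpalaTrunk(nn.Module):
    """One ConvSequence per entry in `depths`, then a linear layer to `nlatents`."""

    def __init__(self, in_channels=3, depths=(64, 128, 128), nlatents=1024, img_size=64):
        super().__init__()
        seqs, c = [], in_channels
        for d in depths:
            seqs.append(ConvSequence(c, d))
            c = d
        self.seqs = nn.Sequential(*seqs)
        spatial = img_size // (2 ** len(depths))  # each sequence halves H and W
        self.fc = nn.Linear(depths[-1] * spatial * spatial, nlatents)

    def forward(self, x):
        h = torch.relu(self.seqs(x))
        return torch.relu(self.fc(torch.flatten(h, start_dim=1)))


if __name__ == "__main__":
    trunk = ImpalaTrunk(depths=(64, 128, 128), nlatents=1024)
    print(trunk(torch.zeros(2, 3, 64, 64)).shape)  # torch.Size([2, 1024])
```

With 64x64 frames, the three sequences reduce the spatial size to 8x8, so under these assumptions the final projection maps 128 * 8 * 8 = 8192 features to 1024 latents.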