Commit 2aeebdfe authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

ppg aux accumulate

parent c4d01795
......@@ -300,14 +300,15 @@ class CustomTorchPolicy(TorchPolicy):
ret_numpy=True, no_grad=True, to_torch=True)
# Tune vf and pi heads to older predictions with (augmented?) observations
num_accumulate = self.config['aux_num_accumulates']
num_rollouts = self.config['aux_mbsize']
for ep in range(retune_epochs):
counter = 0
for slices in self.retune_selector.make_minibatches(replay_pi):
for slices in self.retune_selector.make_minibatches(replay_pi, num_rollouts):
counter += 1
apply_grad = (counter % 2) == 0
apply_grad = (counter % num_accumulate) == 0
self.tune_policy(slices[0], self.to_tensor(slices[1]), self.to_tensor(slices[2]),
apply_grad, num_accumulate=2)
apply_grad, num_accumulate)
self.retunes_completed += 1
self.retune_selector.retune_done()
......
......@@ -98,6 +98,7 @@ DEFAULT_CONFIG = with_common_config({
"single_optimizer": False,
"max_time": 7200,
"pi_phase_mixed_precision": False,
"aux_num_accumulates": 1,
})
# __sphinx_doc_end__
# yapf: enable
......
......@@ -144,7 +144,7 @@ class RetuneSelector:
self.replay_index = 0
def make_minibatches(self, presleep_pi, num_rollouts=4):
def make_minibatches(self, presleep_pi, num_rollouts):
if not self.flat_buffer:
env_segs = list(itertools.product(range(self.n_pi), range(self.nenvs)))
np.random.shuffle(env_segs)
......
......@@ -45,12 +45,13 @@ procgen-ppo:
no_done_at_end: False
# Custom switches
skips: 6
n_pi: 10
skips: 0
n_pi: 16
num_retunes: 16
retune_epochs: 6
standardize_rewards: True
aux_mbsize: 4
aux_num_accumulates: 3
augment_buffer: True
scale_reward: 1.0
reset_returns: False
......@@ -62,7 +63,7 @@ procgen-ppo:
aux_phase_mixed_precision: True
single_optimizer: True
max_time: 7200
pi_phase_mixed_precision: True
pi_phase_mixed_precision: False
adaptive_gamma: False
final_lr: 1.0e-4
......@@ -71,7 +72,7 @@ procgen-ppo:
entropy_schedule: False
# Memory management, if batch size overflow, batch splitting is done to handle it
max_minibatch_size: 500
max_minibatch_size: 1000
updates_per_batch: 8
normalize_actions: False
......@@ -88,10 +89,10 @@ procgen-ppo:
model:
custom_model: impala_torch_ppg
custom_model_config:
# depths: [32, 64, 64]
# nlatents: 512
depths: [64, 128, 128]
nlatents: 1024
depths: [32, 64, 64]
nlatents: 512
# depths: [64, 128, 128]
# nlatents: 1024
init_normed: True
use_layernorm: False
diff_framestack: True
......
......@@ -47,14 +47,14 @@ procgen-ppo:
# Custom switches
retune_skips: 100000
retune_replay_size: 400000
num_retunes: 14
retune_replay_size: 450000
num_retunes: 13
retune_epochs: 3
standardize_rewards: True
scale_reward: 1.0
return_reset: False
aux_phase_mixed_precision: True
max_time: 7200
max_time: 1000000
adaptive_gamma: False
final_lr: 5.0e-5
......@@ -63,7 +63,7 @@ procgen-ppo:
entropy_schedule: False
# Memory management, if batch size overflow, batch splitting is done to handle it
max_minibatch_size: 1000
max_minibatch_size: 2048
updates_per_batch: 8
normalize_actions: False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment