Commit cd5f49a3 authored by Dipam Chakraborty

ppg flatbuffer

parent e6bcbe81
@@ -72,7 +72,8 @@ class CustomTorchPolicy(TorchPolicy):
         self.retune_selector = RetuneSelector(nenvs, observation_space, action_space,
                                               skips = self.config['skips'],
                                               n_pi = n_pi,
-                                              num_retunes = self.config['num_retunes'])
+                                              num_retunes = self.config['num_retunes'],
+                                              flat_buffer = self.config['flattened_buffer'])
         replay_shape = (n_pi, nsteps, nenvs)
         self.exp_replay = np.zeros((*replay_shape, *observation_space.shape), dtype=np.uint8)
@@ -263,7 +264,7 @@ class CustomTorchPolicy(TorchPolicy):
     def tune_policy(self, obs, target_vf, target_pi):
         if self.config['augment_buffer']:
             obs_aug = np.empty(obs.shape, obs.dtype)
-            aug_idx = np.random.randint(6, size=len(obs))
+            aug_idx = np.random.randint(self.config['augment_randint_num'], size=len(obs))
             obs_aug[aug_idx == 0] = pad_and_random_crop(obs[aug_idx == 0], 64, 10)
             obs_aug[aug_idx == 1] = random_cutout_color(obs[aug_idx == 1], 10, 30)
             obs_aug[aug_idx >= 2] = obs[aug_idx >= 2]
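The augmentation switch above controls what fraction of replayed observations get augmented: with augment_randint_num = 6, roughly one sixth of the frames are padded-and-cropped, one sixth get a colored cutout, and the rest pass through unchanged, so raising the value lowers the augmented fraction. A minimal standalone sketch of that indexing pattern, with identity placeholders standing in for pad_and_random_crop and random_cutout_color (the real implementations live elsewhere in the repo):

    import numpy as np

    def pad_and_random_crop(imgs, out_size, pad):
        return imgs  # placeholder for the repo's crop augmentation

    def random_cutout_color(imgs, min_cut, max_cut):
        return imgs  # placeholder for the repo's cutout augmentation

    obs = np.random.randint(0, 256, size=(12, 64, 64, 3), dtype=np.uint8)
    augment_randint_num = 6  # same default as the new config key

    obs_aug = np.empty(obs.shape, obs.dtype)
    aug_idx = np.random.randint(augment_randint_num, size=len(obs))
    obs_aug[aug_idx == 0] = pad_and_random_crop(obs[aug_idx == 0], 64, 10)  # ~1/6 of frames
    obs_aug[aug_idx == 1] = random_cutout_color(obs[aug_idx == 1], 10, 30)  # ~1/6 of frames
    obs_aug[aug_idx >= 2] = obs[aug_idx >= 2]                               # rest untouched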
@@ -89,6 +89,8 @@ DEFAULT_CONFIG = with_common_config({
     "aux_mbsize": 4,
     "augment_buffer": False,
     "reset_returns": True,
+    "flattened_buffer": False,
+    "augment_randint_num": 6,
 })
 # __sphinx_doc_end__
 # yapf: enable
@@ -102,9 +102,13 @@ class AdaptiveDiscountTuner:
 def flatten01(arr):
     return arr.reshape(-1, *arr.shape[2:])

+def flatten012(arr):
+    return arr.reshape(-1, *arr.shape[3:])
+
 class RetuneSelector:
-    def __init__(self, nenvs, ob_space, ac_space, skips = 0, n_pi = 32, num_retunes = 5):
+    def __init__(self, nenvs, ob_space, ac_space, skips = 0, n_pi = 32, num_retunes = 5, flat_buffer=False):
         self.skips = skips
         self.n_pi = n_pi
         self.nenvs = nenvs
@@ -115,6 +119,7 @@ class RetuneSelector:
         self.cooldown_counter = 0
         self.replay_index = 0
+        self.flat_buffer = flat_buffer

     def update(self, obs_batch, vtarg_batch, exp_replay, vtarg_replay):
         if self.num_retunes == 0:
@@ -137,14 +142,27 @@ class RetuneSelector:
     def make_minibatches_with_rollouts(self, exp_replay, vtarg_replay, presleep_pi, num_rollouts=4):
-        env_segs = list(itertools.product(range(self.n_pi), range(self.nenvs)))
-        np.random.shuffle(env_segs)
-        env_segs = np.array(env_segs)
-        for idx in range(0, len(env_segs), num_rollouts):
-            esinds = env_segs[idx:idx+num_rollouts]
-            mbatch = [flatten01(arr[esinds[:,0], : , esinds[:,1]])
-                      for arr in (exp_replay, vtarg_replay, presleep_pi)]
-            yield mbatch
+        if not self.flat_buffer:
+            env_segs = list(itertools.product(range(self.n_pi), range(self.nenvs)))
+            np.random.shuffle(env_segs)
+            env_segs = np.array(env_segs)
+            for idx in range(0, len(env_segs), num_rollouts):
+                esinds = env_segs[idx:idx+num_rollouts]
+                mbatch = [flatten01(arr[esinds[:,0], : , esinds[:,1]])
+                          for arr in (exp_replay, vtarg_replay, presleep_pi)]
+                yield mbatch
+        else:
+            nsteps = vtarg_replay.shape[1]
+            buffsize = self.n_pi * nsteps * self.nenvs
+            inds = np.arange(buffsize)
+            np.random.shuffle(inds)
+            batchsize = num_rollouts * nsteps
+            for start in range(0, buffsize, batchsize):
+                end = start+batchsize
+                mbinds = inds[start:end]
+                mbatch = [flatten012(arr)[mbinds]
+                          for arr in (exp_replay, vtarg_replay, presleep_pi)]
+                yield mbatch

 class RewardNormalizer(object):
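Taken together, the utils changes give make_minibatches_with_rollouts two sampling modes. The original path samples whole (segment, env) rollouts and flattens only the time axis with flatten01, so each minibatch keeps temporally contiguous trajectories; the new flattened_buffer path treats the whole (n_pi, nsteps, nenvs) replay as one flat pool and samples individual timesteps uniformly via flatten012. A small self-contained sketch of the two index patterns on a toy value-target buffer (shapes are illustrative, not the experiment's real sizes):

    import itertools
    import numpy as np

    n_pi, nsteps, nenvs = 4, 8, 3   # toy sizes
    vtarg_replay = np.arange(n_pi * nsteps * nenvs, dtype=np.float32).reshape(n_pi, nsteps, nenvs)

    def flatten01(arr):
        return arr.reshape(-1, *arr.shape[2:])

    def flatten012(arr):
        return arr.reshape(-1, *arr.shape[3:])

    # rollout mode: a minibatch is a few full rollouts, time kept contiguous
    env_segs = np.array(list(itertools.product(range(n_pi), range(nenvs))))
    np.random.shuffle(env_segs)
    esinds = env_segs[:2]                                   # num_rollouts = 2
    rollout_batch = flatten01(vtarg_replay[esinds[:, 0], :, esinds[:, 1]])
    print(rollout_batch.shape)                              # (2 * nsteps,) = (16,)

    # flat mode: ignore rollout boundaries and sample timesteps uniformly
    buffsize = n_pi * nsteps * nenvs
    inds = np.arange(buffsize)
    np.random.shuffle(inds)
    flat_batch = flatten012(vtarg_replay)[inds[:2 * nsteps]]
    print(flat_batch.shape)                                 # also (16,), but scattered timesteps

For the observation buffer, which carries extra image dimensions after the first three axes, flatten012 keeps those dimensions while collapsing segment, time, and env into a single batch axis.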
@@ -46,19 +46,21 @@ procgen-ppo:
     no_done_at_end: False

     # Custom switches
-    skips: 0
+    skips: 9
     n_pi: 9
     num_retunes: 100
-    retune_epochs: 3
+    retune_epochs: 6
     standardize_rewards: True
     aux_mbsize: 4
-    augment_buffer: False
+    augment_buffer: True
     scale_reward: 1.0
-    reset_returns: True
+    reset_returns: False
+    flattened_buffer: True
+    augment_randint_num: 6 ## Hacky name fix later
     adaptive_gamma: False
-    final_lr: 2.0e-4
-    lr_schedule: 'None'
+    final_lr: 5.0e-5
+    lr_schedule: 'linear'
     final_entropy_coeff: 0.002
     entropy_schedule: False
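The experiment config also moves from a flat learning rate to lr_schedule: 'linear' with a smaller final_lr. This diff does not show how those keys are consumed; as a rough illustration only, a linear schedule of this kind is typically a straight interpolation from the starting learning rate down to final_lr over the training budget (the starting value and horizon below are assumptions, not values from this commit):

    def linear_lr(step, total_steps, initial_lr=5.0e-4, final_lr=5.0e-5):
        # illustrative only: anneal linearly from initial_lr to final_lr over total_steps
        frac = min(step / total_steps, 1.0)
        return initial_lr + frac * (final_lr - initial_lr)

    print(linear_lr(4_000_000, 8_000_000))   # halfway through training -> 0.000275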
@@ -6,8 +6,8 @@ set -e
 #########################################
 # export EXPERIMENT_DEFAULT="experiments/impala-baseline.yaml"
-export EXPERIMENT_DEFAULT="experiments/custom-torch-ppo.yaml"
-# export EXPERIMENT_DEFAULT="experiments/custom-ppg.yaml"
+# export EXPERIMENT_DEFAULT="experiments/custom-torch-ppo.yaml"
+export EXPERIMENT_DEFAULT="experiments/custom-ppg.yaml"
 export EXPERIMENT=${EXPERIMENT:-$EXPERIMENT_DEFAULT}

 if [[ -z $AICROWD_IS_GRADING ]]; then