Commit 4f1dd1ac authored by Dipam Chakraborty

aux sample full rollout

parent 0d488563
@@ -58,7 +58,10 @@ class CustomTorchPolicy(TorchPolicy):
self.best_rew_tsteps = 0
nw = self.config['num_workers'] if self.config['num_workers'] > 0 else 1
self.nbatch = nw * self.config['num_envs_per_worker'] * self.config['rollout_fragment_length']
nenvs = nw * self.config['num_envs_per_worker']
nsteps = self.config['rollout_fragment_length']
n_pi = self.config['n_pi']
self.nbatch = nenvs * nsteps
self.actual_batch_size = self.nbatch // self.config['updates_per_batch']
self.accumulate_train_batches = int(np.ceil( self.actual_batch_size / self.config['max_minibatch_size'] ))
self.mem_limited_batch_size = self.actual_batch_size // self.accumulate_train_batches
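For intuition, here is a worked pass through the batching arithmetic above, using the worker/env/fragment values from the experiment yaml at the bottom of this commit (illustrative numbers, not authoritative defaults):

import numpy as np

nw, envs_per_worker, nsteps = 4, 16, 256               # num_workers, num_envs_per_worker, rollout_fragment_length
nbatch = nw * envs_per_worker * nsteps                 # 16384 transitions per training batch
updates_per_batch = 8
actual_batch_size = nbatch // updates_per_batch        # 2048
max_minibatch_size = 2048
accumulate_train_batches = int(np.ceil(actual_batch_size / max_minibatch_size))  # 1, so no gradient accumulation
mem_limited_batch_size = actual_batch_size // accumulate_train_batches           # 2048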
@@ -66,13 +69,14 @@ class CustomTorchPolicy(TorchPolicy):
print("#################################################")
print("WARNING: MEMORY LIMITED BATCHING NOT SET PROPERLY")
print("#################################################")
self.retune_selector = RetuneSelector(self.nbatch, observation_space, action_space,
skips = self.config['retune_skips'],
replay_size = self.config['retune_replay_size'],
self.retune_selector = RetuneSelector(nenvs, observation_space, action_space,
skips = self.config['skips'],
n_pi = n_pi,
num_retunes = self.config['num_retunes'])
self.exp_replay = np.zeros((self.retune_selector.replay_size, *observation_space.shape), dtype=np.uint8)
self.vtarg_replay = np.zeros((self.retune_selector.replay_size), dtype=np.float32)
replay_shape = (n_pi, nsteps, nenvs)
self.exp_replay = np.empty((*replay_shape, *observation_space.shape), dtype=np.uint8)
self.vtarg_replay = np.empty(replay_shape, dtype=np.float32)
self.save_success = 0
self.target_timesteps = 8_000_000
self.buffer_time = 20 # TODO: Could try to do a median or mean time step check instead
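The replay buffers are now indexed as (policy phase, step, env) rather than being one flat transition buffer. A minimal sketch of the resulting shapes, with nenvs, nsteps and n_pi taken from the experiment yaml below and the 64x64x3 observation shape assumed (as for Procgen):

import numpy as np

nenvs, nsteps, n_pi = 4 * 16, 256, 2     # num_workers * num_envs_per_worker, rollout_fragment_length, n_pi
obs_shape = (64, 64, 3)                  # assumed Procgen frame shape

replay_shape = (n_pi, nsteps, nenvs)
exp_replay = np.empty((*replay_shape, *obs_shape), dtype=np.uint8)   # (2, 256, 64, 64, 64, 3) frames
vtarg_replay = np.empty(replay_shape, dtype=np.float32)              # one value target per stored transition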
@@ -121,6 +125,11 @@ class CustomTorchPolicy(TorchPolicy):
self.last_dones = mb_dones[-1]
else:
mb_rewards = unroll(samples['rewards'], ts)
# Weird hack that helps in many envs (note: keep this after reward normalization)
rew_scale = self.config["scale_reward"]
if rew_scale != 1.0:
mb_rewards *= rew_scale
should_skip_train_step = self.best_reward_model_select(samples)
if should_skip_train_step:
@@ -181,7 +190,7 @@ class CustomTorchPolicy(TorchPolicy):
lrnow, cliprange, vfcliprange, max_grad_norm, ent_coef, vf_coef, *slices)
## Distill with aux head
should_retune = self.retune_selector.update(obs, returns, self.exp_replay, self.vtarg_replay)
should_retune = self.retune_selector.update(unroll(obs, ts), mb_returns, self.exp_replay, self.vtarg_replay)
if should_retune:
self.aux_train()
self.update_batch_time()
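The unroll helper used above is defined elsewhere in the repo and is not part of this diff; a plausible sketch of what it has to do here, assuming ts = (nsteps, nenvs) and the usual RLlib layout of consecutive per-env fragments in the flat batch:

import numpy as np

def unroll_sketch(flat, ts):
    # flat: (nsteps * nenvs, ...) transitions, laid out as nenvs consecutive
    # fragments of nsteps each. Returns shape (nsteps, nenvs, ...), which is
    # what the (n_pi, nsteps, nenvs) replay buffers above expect per phase.
    nsteps, nenvs = ts
    return flat.reshape(nenvs, nsteps, *flat.shape[1:]).swapaxes(0, 1)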
@@ -231,43 +240,34 @@ class CustomTorchPolicy(TorchPolicy):
for g in self.aux_optimizer.param_groups:
g['lr'] = self.lr
nbatch_train = self.mem_limited_batch_size
aux_nbatch_train = self.config['aux_mbsize']
retune_epochs = self.config['retune_epochs']
replay_size = self.retune_selector.replay_size
replay_pi = np.empty((replay_size, self.retune_selector.ac_space.n), dtype=np.float32)
replay_shape = self.vtarg_replay.shape
replay_pi = np.empty((*replay_shape, self.retune_selector.ac_space.n), dtype=np.float32)
# Store current value function and policy logits
for start in range(0, replay_size, nbatch_train):
end = start + nbatch_train
replay_batch = self.exp_replay[start:end]
_, replay_pi[start:end] = self.model.vf_pi(replay_batch,
ret_numpy=True, no_grad=True, to_torch=True)
for nnpi in range(self.retune_selector.n_pi):
for ne in range(self.retune_selector.nenvs):
_, replay_pi[nnpi, :, ne] = self.model.vf_pi(self.exp_replay[nnpi, :, ne],
ret_numpy=True, no_grad=True, to_torch=True)
optim_count = 0
# Tune vf and pi heads to older predictions with augmented observations
inds = np.arange(len(self.exp_replay))
# Tune vf and pi heads to older predictions (observations optionally augmented)
for ep in range(retune_epochs):
np.random.shuffle(inds)
for start in range(0, replay_size, aux_nbatch_train):
end = start + aux_nbatch_train
mbinds = inds[start:end]
optim_count += 1
slices = [self.exp_replay[mbinds],
self.to_tensor(self.vtarg_replay[mbinds]),
self.to_tensor(replay_pi[mbinds])]
self.tune_policy(*slices)
for slices in self.retune_selector.make_minibatches_with_rollouts(self.exp_replay, self.vtarg_replay, replay_pi):
self.tune_policy(slices[0], self.to_tensor(slices[1]), self.to_tensor(slices[2]))
self.retune_selector.retune_done()
def tune_policy(self, obs, target_vf, target_pi):
obs_aug = np.empty(obs.shape, obs.dtype)
aug_idx = np.random.randint(6, size=len(obs))
obs_aug[aug_idx == 0] = pad_and_random_crop(obs[aug_idx == 0], 64, 10)
obs_aug[aug_idx == 1] = random_cutout_color(obs[aug_idx == 1], 10, 30)
obs_aug[aug_idx >= 2] = obs[aug_idx >= 2]
obs_aug = self.to_tensor(obs_aug)
if self.config['augment_buffer']:
obs_aug = np.empty(obs.shape, obs.dtype)
aug_idx = np.random.randint(6, size=len(obs))
obs_aug[aug_idx == 0] = pad_and_random_crop(obs[aug_idx == 0], 64, 10)
obs_aug[aug_idx == 1] = random_cutout_color(obs[aug_idx == 1], 10, 30)
obs_aug[aug_idx >= 2] = obs[aug_idx >= 2]
obs_in = self.to_tensor(obs_aug)
else:
obs_in = self.to_tensor(obs)
vpred, pi_logits = self.model.vf_pi(obs_aug, ret_numpy=False, no_grad=False, to_torch=False)
vpred, pi_logits = self.model.vf_pi(obs_in, ret_numpy=False, no_grad=False, to_torch=False)
aux_vpred = self.model.aux_value_function()
vf_loss = .5 * torch.mean(torch.pow(vpred - target_vf, 2))
aux_loss = .5 * torch.mean(torch.pow(aux_vpred - target_vf, 2))
......
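pad_and_random_crop and random_cutout_color used in tune_policy above also live elsewhere in the repo; the sketches below (note the _sketch suffix) only illustrate the kind of augmentation being applied and are not the actual implementations:

import numpy as np

def pad_and_random_crop_sketch(imgs, out_size=64, pad=10):
    # Zero-pad each (H, W, C) frame by `pad` pixels per side, then crop a random
    # out_size x out_size window, i.e. a random shift of up to `pad` pixels.
    n = len(imgs)
    padded = np.pad(imgs, ((0, 0), (pad, pad), (pad, pad), (0, 0)))
    out = np.empty((n, out_size, out_size, imgs.shape[-1]), dtype=imgs.dtype)
    offsets = np.random.randint(0, 2 * pad + 1, size=(n, 2))
    for i, (dy, dx) in enumerate(offsets):
        out[i] = padded[i, dy:dy + out_size, dx:dx + out_size]
    return out

def random_cutout_color_sketch(imgs, min_cut=10, max_cut=30):
    # Overwrite a random box (min_cut..max_cut pixels per side) in each frame
    # with a single random colour.
    out = imgs.copy()
    for img in out:
        h, w = img.shape[:2]
        ch, cw = np.random.randint(min_cut, max_cut + 1, size=2)
        y, x = np.random.randint(0, h - ch + 1), np.random.randint(0, w - cw + 1)
        img[y:y + ch, x:x + cw] = np.random.randint(0, 256, size=img.shape[-1])
    return out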
@@ -2,8 +2,8 @@ import logging
from ray.rllib.agents import with_common_config
from .custom_torch_ppg import CustomTorchPolicy
# from ray.rllib.agents.trainer_template import build_trainer
from .custom_trainer_template import build_trainer
from ray.rllib.agents.trainer_template import build_trainer
# from .custom_trainer_template import build_trainer
logger = logging.getLogger(__name__)
@@ -71,21 +71,23 @@ DEFAULT_CONFIG = with_common_config({
"use_pytorch": True,
# Custom switches
"retune_skips": 300000,
"retune_replay_size": 200000,
"num_retunes": 6,
"retune_epochs": 3,
"standardize_rewards": False,
"skips": 0,
"n_pi": 32,
"num_retunes": 100,
"retune_epochs": 6,
"standardize_rewards": True,
"scale_reward": 1.0,
"accumulate_train_batches": 1,
"adaptive_gamma": False,
"final_lr": 1e-4,
"lr_schedule": True,
"final_lr": 2e-4,
"lr_schedule": 'None',
"final_entropy_coeff": 0.002,
"entropy_schedule": True,
"entropy_schedule": False,
"max_minibatch_size": 2048,
"updates_per_batch": 8,
"aux_mbsize": 512,
"aux_mbsize": 4,
"augment_buffer": False,
})
# __sphinx_doc_end__
# yapf: enable
......
@@ -6,6 +6,7 @@ from skimage.util import view_as_windows
torch, nn = try_import_torch()
import torch.distributions as td
from functools import partial
import itertools
def _make_categorical(x, ncat, shape):
x = x.reshape((x.shape[0], shape, ncat))
@@ -98,22 +99,22 @@ class AdaptiveDiscountTuner:
gtarg = horizon_to_gamma(htarg)
self.gamma = self.gamma * self.momentum + gtarg * (1-self.momentum)
return self.gamma
def flatten01(arr):
return arr.reshape(-1, *arr.shape[2:])
class RetuneSelector:
def __init__(self, nbatch, ob_space, ac_space, skips = 800_000, replay_size = 200_000, num_retunes = 5):
self.skips = skips + (-skips) % nbatch
self.replay_size = replay_size + (-replay_size) % nbatch
self.batch_size = nbatch
self.batches_in_replay = self.replay_size // nbatch
def __init__(self, nenvs, ob_space, ac_space, skips = 0, n_pi = 32, num_retunes = 5):
self.skips = skips
self.n_pi = n_pi
self.nenvs = nenvs
self.num_retunes = num_retunes
self.ac_space = ac_space
self.ob_space = ob_space
self.cooldown_counter = self.skips // self.batch_size
self.cooldown_counter = 0
self.replay_index = 0
self.buffer_full = False
def update(self, obs_batch, vtarg_batch, exp_replay, vtarg_replay):
if self.num_retunes == 0:
@@ -123,26 +124,29 @@ class RetuneSelector:
self.cooldown_counter -= 1
return False
start = self.replay_index * self.batch_size
end = start + self.batch_size
exp_replay[start:end] = obs_batch
vtarg_replay[start:end] = vtarg_batch
exp_replay[self.replay_index] = obs_batch
vtarg_replay[self.replay_index] = vtarg_batch
self.replay_index = (self.replay_index + 1) % self.batches_in_replay
self.buffer_full = self.buffer_full or (self.replay_index == 0)
return self.buffer_full
self.replay_index = (self.replay_index + 1) % self.n_pi
return self.replay_index == 0
def retune_done(self):
self.cooldown_counter = self.skips // self.batch_size
self.cooldown_counter = self.skips
self.num_retunes -= 1
self.replay_index = 0
self.buffer_full = False
def set_num_retunes(self, nr):
self.num_retunes = nr
def make_minibatches_with_rollouts(self, exp_replay, vtarg_replay, presleep_pi, num_rollouts=4):
env_segs = list(itertools.product(range(self.n_pi), range(self.nenvs)))
np.random.shuffle(env_segs)
env_segs = np.array(env_segs)
for idx in range(0, len(env_segs), num_rollouts):
esinds = env_segs[idx:idx+num_rollouts]
mbatch = [flatten01(arr[esinds[:,0], : , esinds[:,1]])
for arr in (exp_replay, vtarg_replay, presleep_pi)]
yield mbatch
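A rough usage sketch of make_minibatches_with_rollouts with illustrative sizes, assuming RetuneSelector and flatten01 from this module are importable:

import numpy as np

nenvs, nsteps, n_pi, ncat = 8, 16, 4, 15     # illustrative sizes only
selector = RetuneSelector(nenvs, ob_space=None, ac_space=None, skips=0, n_pi=n_pi)

exp_replay = np.zeros((n_pi, nsteps, nenvs, 64, 64, 3), dtype=np.uint8)
vtarg_replay = np.zeros((n_pi, nsteps, nenvs), dtype=np.float32)
replay_pi = np.zeros((n_pi, nsteps, nenvs, ncat), dtype=np.float32)

for obs_mb, vtarg_mb, pi_mb in selector.make_minibatches_with_rollouts(
        exp_replay, vtarg_replay, replay_pi, num_rollouts=4):
    # Each minibatch stacks num_rollouts full single-env rollouts end to end.
    assert obs_mb.shape == (4 * nsteps, 64, 64, 3)
    assert vtarg_mb.shape == (4 * nsteps,)
    assert pi_mb.shape == (4 * nsteps, ncat)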
class RewardNormalizer(object):
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
def __init__(self, gamma=0.99, cliprew=10.0, epsilon=1e-8):
......
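The body of RewardNormalizer is collapsed in this diff; the update rule it cites (the parallel variance algorithm from the linked Wikipedia page) can be sketched as follows, purely for reference and not as the actual implementation:

import numpy as np

class RunningMeanStdSketch:
    # Batched/parallel mean-variance update, as described at the link above.
    def __init__(self, epsilon=1e-4):
        self.mean, self.var, self.count = 0.0, 1.0, epsilon

    def update(self, x):
        x = np.asarray(x, dtype=np.float64)
        batch_mean, batch_var, batch_count = x.mean(), x.var(), x.size
        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        m2 = self.var * self.count + batch_var * batch_count \
             + delta ** 2 * self.count * batch_count / total
        self.mean, self.var, self.count = new_mean, m2 / total, total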
@@ -46,12 +46,14 @@ procgen-ppo:
no_done_at_end: False
# Custom switches
retune_skips: 300000
retune_replay_size: 200000
skips: 6
n_pi: 2
num_retunes: 20
retune_epochs: 6
standardize_rewards: True
aux_mbsize: 1024
aux_mbsize: 4
augment_buffer: False
scale_reward: 1.0
adaptive_gamma: False
final_lr: 2.0e-4
@@ -78,14 +80,14 @@ procgen-ppo:
model:
custom_model: impala_torch_ppg
custom_options:
# depths: [64, 128, 128]
# nlatents: 1024
depths: [32, 64, 64]
nlatents: 512
depths: [16, 32, 32]
nlatents: 256
# depths: [32, 64, 64]
# nlatents: 512
init_normed: True
use_layernorm: False
num_workers: 7
num_workers: 4
num_envs_per_worker: 16
rollout_fragment_length: 256
......