Commit 669eed64 authored by Dipam Chakraborty's avatar Dipam Chakraborty

ppg fixes

parent cd5f49a3
@@ -43,10 +43,15 @@ class CustomTorchPolicy(TorchPolicy):
             loss=None,
             action_distribution_class=dist_class,
         )
         self.framework = "torch"
-        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)
-        self.aux_optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)
+        aux_params = set(self.model.aux_vf.parameters())
+        value_params = set(self.model.value_fc.parameters())
+        network_params = set(self.model.parameters())
+        aux_optim_params = list(network_params - value_params)
+        ppo_optim_params = list(network_params - aux_params - value_params)
+        self.optimizer = torch.optim.Adam(ppo_optim_params, lr=5e-4)
+        self.aux_optimizer = torch.optim.Adam(aux_optim_params, lr=5e-4)
+        self.value_optimizer = torch.optim.Adam(value_params, lr=1e-3)
         self.max_reward = self.config['env_config']['return_max']
         self.rewnorm = RewardNormalizer(cliprew=self.max_reward) ## TODO: Might need to go to custom state
         self.reward_deque = deque(maxlen=100)
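The single Adam over all parameters is replaced by three optimizers over mostly disjoint parameter groups, so the PPO phase, the auxiliary (sleep) phase, and the value head can be stepped independently. A minimal sketch of the same set-difference partition, assuming a stand-in module whose head names (value_fc, aux_vf) come from the diff and whose layer sizes are made up:

```python
import torch
import torch.nn as nn

# Stand-in network: shared trunk plus the two heads named in the diff.
# Layer sizes are arbitrary placeholders; no forward() is needed for this demo.
class TinyPolicy(nn.Module):
    def __init__(self):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(64, 64), nn.ReLU())
        self.pi = nn.Linear(64, 15)        # policy logits
        self.value_fc = nn.Linear(64, 1)   # true value head
        self.aux_vf = nn.Linear(64, 1)     # auxiliary value head (PPG)

model = TinyPolicy()

# nn.Parameter objects hash by identity, so plain set arithmetic partitions them.
aux_params = set(model.aux_vf.parameters())
value_params = set(model.value_fc.parameters())
network_params = set(model.parameters())

aux_optim_params = list(network_params - value_params)               # trunk + pi + aux_vf
ppo_optim_params = list(network_params - aux_params - value_params)  # trunk + pi only

optimizer = torch.optim.Adam(ppo_optim_params, lr=5e-4)          # PPO (wake) phase
aux_optimizer = torch.optim.Adam(aux_optim_params, lr=5e-4)      # distillation (sleep) phase
value_optimizer = torch.optim.Adam(list(value_params), lr=1e-3)  # value head only
```

Because value_fc is subtracted from both ppo_optim_params and aux_optim_params, only value_optimizer ever updates the value head, which is what lets it train at its own higher learning rate (1e-3) during both phases.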
@@ -69,15 +74,12 @@ class CustomTorchPolicy(TorchPolicy):
             print("#################################################")
             print("WARNING: MEMORY LIMITED BATCHING NOT SET PROPERLY")
             print("#################################################")
-        self.retune_selector = RetuneSelector(nenvs, observation_space, action_space,
+        replay_shape = (n_pi, nsteps, nenvs)
+        self.retune_selector = RetuneSelector(nenvs, observation_space, action_space, replay_shape,
                                               skips = self.config['skips'],
                                               n_pi = n_pi,
                                               num_retunes = self.config['num_retunes'],
                                               flat_buffer = self.config['flattened_buffer'])
-        replay_shape = (n_pi, nsteps, nenvs)
-        self.exp_replay = np.zeros((*replay_shape, *observation_space.shape), dtype=np.uint8)
-        self.vtarg_replay = np.empty(replay_shape, dtype=np.float32)
         self.save_success = 0
         self.target_timesteps = 8_000_000
         self.buffer_time = 20 # TODO: Could try to do a median or mean time step check instead
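The experience and value-target replays are now allocated inside RetuneSelector with shape (n_pi, nsteps, nenvs) plus the observation shape. A rough footprint check, assuming Procgen-style 64x64 RGB observations with frame_stack 2 (6 channels, matching the c == 6 assertion later in this commit); the n_pi/nsteps/nenvs values below are placeholders, not taken from this config:

```python
import numpy as np

n_pi, nsteps, nenvs = 16, 256, 64   # placeholder sizes, not from this config
obs_shape = (64, 64, 6)             # 64x64 RGB, frame_stack = 2 -> 6 channels

replay_shape = (n_pi, nsteps, nenvs)
exp_replay = np.zeros((*replay_shape, *obs_shape), dtype=np.uint8)   # stored observations
vtarg_replay = np.empty(replay_shape, dtype=np.float32)              # stored value targets

# uint8 storage is what keeps the observation buffer manageable:
print(exp_replay.nbytes / 2**30, "GiB")    # 6.0 GiB for these sizes
print(vtarg_replay.nbytes / 2**20, "MiB")  # 1.0 MiB
```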
@@ -192,7 +194,7 @@ class CustomTorchPolicy(TorchPolicy):
                               lrnow, cliprange, vfcliprange, max_grad_norm, ent_coef, vf_coef, *slices)
 
         ## Distill with aux head
-        should_retune = self.retune_selector.update(unroll(obs, ts), mb_returns, self.exp_replay, self.vtarg_replay)
+        should_retune = self.retune_selector.update(unroll(obs, ts), mb_returns)
         if should_retune:
             self.aux_train()
             self.update_batch_time()
@@ -221,43 +223,45 @@ class CustomTorchPolicy(TorchPolicy):
         neglogpac = -pd.log_prob(actions[...,None]).squeeze(1)
         entropy = torch.mean(pd.entropy())
-        vf_loss = .5 * torch.mean(torch.pow((vpred - returns), 2))
+        vf_loss = .5 * torch.mean(torch.pow((vpred - returns), 2)) * vf_coef
         ratio = torch.exp(neglogpac_old - neglogpac)
         pg_losses1 = -advs * ratio
         pg_losses2 = -advs * torch.clamp(ratio, 1-cliprange, 1+cliprange)
         pg_loss = torch.mean(torch.max(pg_losses1, pg_losses2))
-        loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef
+        loss = pg_loss - entropy * ent_coef
         loss = loss / num_accumulate
+        vf_loss = vf_loss / num_accumulate
         loss.backward()
+        vf_loss.backward()
         if apply_grad:
             self.optimizer.step()
+            self.value_optimizer.step()
             self.optimizer.zero_grad()
+            self.value_optimizer.zero_grad()
 
     def aux_train(self):
         for g in self.aux_optimizer.param_groups:
-            g['lr'] = self.lr
+            g['lr'] = self.config['aux_lr']
         nbatch_train = self.mem_limited_batch_size
         retune_epochs = self.config['retune_epochs']
-        replay_shape = self.vtarg_replay.shape
+        replay_shape = self.retune_selector.vtarg_replay.shape
         replay_pi = np.empty((*replay_shape, self.retune_selector.ac_space.n), dtype=np.float32)
         for nnpi in range(self.retune_selector.n_pi):
             for ne in range(self.retune_selector.nenvs):
-                _, replay_pi[nnpi, :, ne] = self.model.vf_pi(self.exp_replay[nnpi, :, ne],
+                _, replay_pi[nnpi, :, ne] = self.model.vf_pi(self.retune_selector.exp_replay[nnpi, :, ne],
                                                              ret_numpy=True, no_grad=True, to_torch=True)
         # Tune vf and pi heads to older predictions with (augmented?) observations
         for ep in range(retune_epochs):
-            for slices in self.retune_selector.make_minibatches_with_rollouts(self.exp_replay, self.vtarg_replay, replay_pi):
+            for slices in self.retune_selector.make_minibatches(replay_pi):
                 self.tune_policy(slices[0], self.to_tensor(slices[1]), self.to_tensor(slices[2]))
-        self.exp_replay.fill(0)
-        self.vtarg_replay.fill(0)
         self.retunes_completed += 1
+        self.retune_selector.retune_done()
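In the PPO phase above, the value loss is no longer folded into the policy loss: each loss is scaled by num_accumulate, backpropagated, and applied by its own optimizer once apply_grad is set. A minimal sketch of that accumulation pattern with dummy data; the commit calls backward() on the two losses separately, while this sketch sums them into one backward because the toy heads share a trunk and two separate calls on a shared graph would need retain_graph=True:

```python
import torch
import torch.nn as nn

trunk, pi_head, vf_head = nn.Linear(8, 8), nn.Linear(8, 4), nn.Linear(8, 1)
policy_optimizer = torch.optim.Adam(list(trunk.parameters()) + list(pi_head.parameters()), lr=5e-4)
value_optimizer = torch.optim.Adam(vf_head.parameters(), lr=1e-3)

num_accumulate = 4
minibatches = [torch.randn(16, 8) for _ in range(num_accumulate)]

for i, obs in enumerate(minibatches):
    h = torch.relu(trunk(obs))
    logits, vpred = pi_head(h), vf_head(h).squeeze(-1)

    pg_loss = -logits.mean()                                            # dummy stand-in for the clipped PPO loss
    vf_loss = 0.5 * torch.mean((vpred - torch.zeros_like(vpred)) ** 2)  # dummy value targets

    # Scale by the number of accumulation steps so the summed gradient matches a full-batch update.
    ((pg_loss + vf_loss) / num_accumulate).backward()

    apply_grad = (i == num_accumulate - 1)
    if apply_grad:
        policy_optimizer.step()
        value_optimizer.step()
        policy_optimizer.zero_grad()
        value_optimizer.zero_grad()
```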
@@ -274,19 +278,24 @@ class CustomTorchPolicy(TorchPolicy):
         vpred, pi_logits = self.model.vf_pi(obs_in, ret_numpy=False, no_grad=False, to_torch=False)
         aux_vpred = self.model.aux_value_function()
-        vf_loss = .5 * torch.mean(torch.pow(vpred - target_vf, 2))
         aux_loss = .5 * torch.mean(torch.pow(aux_vpred - target_vf, 2))
         target_pd = self.make_distr(target_pi)
         pd = self.make_distr(pi_logits)
         pi_loss = td.kl_divergence(target_pd, pd).mean()
-        loss = vf_loss + pi_loss + aux_loss
+        loss = pi_loss + aux_loss
         loss.backward()
         self.aux_optimizer.step()
         self.aux_optimizer.zero_grad()
+        vf_loss = .5 * torch.mean(torch.pow(vpred - target_vf, 2))
+        vf_loss.backward()
+        self.value_optimizer.step()
+        self.value_optimizer.zero_grad()
 
     def best_reward_model_select(self, samples):
         self.timesteps_total += len(samples['dones'])
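tune_policy now splits the sleep-phase update the same way: the KL term against the stored pre-sleep policy (the replay_pi snapshot collected in aux_train) plus the auxiliary value MSE go through aux_optimizer, and the true value head gets a separate MSE step through value_optimizer. A condensed sketch with torch.distributions; the module layout is a stand-in, make_distr is assumed to wrap a Categorical over logits, and the value head reads a detached trunk output here so its backward survives the first one (the actual model may wire this differently):

```python
import torch
import torch.nn as nn
import torch.distributions as td

n_actions = 15
trunk = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
pi_head, aux_head, value_fc = nn.Linear(32, n_actions), nn.Linear(32, 1), nn.Linear(32, 1)

aux_optimizer = torch.optim.Adam(
    list(trunk.parameters()) + list(pi_head.parameters()) + list(aux_head.parameters()), lr=5e-4)
value_optimizer = torch.optim.Adam(value_fc.parameters(), lr=1e-3)

obs = torch.randn(64, 32)
target_vf = torch.randn(64)             # stored value targets from the replay
target_pi = torch.randn(64, n_actions)  # pre-sleep policy logits from the replay

h = trunk(obs)
pi_logits, aux_vpred = pi_head(h), aux_head(h).squeeze(-1)
vpred = value_fc(h.detach()).squeeze(-1)

# Joint distillation loss for the policy/aux parameters.
pi_loss = td.kl_divergence(td.Categorical(logits=target_pi), td.Categorical(logits=pi_logits)).mean()
aux_loss = 0.5 * torch.mean((aux_vpred - target_vf) ** 2)
(pi_loss + aux_loss).backward()
aux_optimizer.step()
aux_optimizer.zero_grad()

# Separate step for the true value head on its own optimizer.
vf_loss = 0.5 * torch.mean((vpred - target_vf) ** 2)
vf_loss.backward()
value_optimizer.step()
value_optimizer.zero_grad()
```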
@@ -91,6 +91,7 @@ DEFAULT_CONFIG = with_common_config({
     "reset_returns": True,
     "flattened_buffer": False,
     "augment_randint_num": 6,
+    "aux_lr": 5e-4,
 })
 # __sphinx_doc_end__
 # yapf: enable
@@ -108,11 +108,14 @@ def flatten012(arr):
 class RetuneSelector:
-    def __init__(self, nenvs, ob_space, ac_space, skips = 0, n_pi = 32, num_retunes = 5, flat_buffer=False):
+    def __init__(self, nenvs, ob_space, ac_space, replay_shape, skips = 0, n_pi = 32, num_retunes = 5, flat_buffer=False):
         self.skips = skips
         self.n_pi = n_pi
         self.nenvs = nenvs
+        self.exp_replay = np.zeros((*replay_shape, *ob_space.shape), dtype=np.uint8)
+        self.vtarg_replay = np.empty(replay_shape, dtype=np.float32)
         self.num_retunes = num_retunes
         self.ac_space = ac_space
         self.ob_space = ob_space
@@ -121,7 +124,7 @@ class RetuneSelector:
         self.replay_index = 0
         self.flat_buffer = flat_buffer
 
-    def update(self, obs_batch, vtarg_batch, exp_replay, vtarg_replay):
+    def update(self, obs_batch, vtarg_batch):
         if self.num_retunes == 0:
             return False
@@ -129,8 +132,8 @@ class RetuneSelector:
             self.cooldown_counter -= 1
             return False
 
-        exp_replay[self.replay_index] = obs_batch
-        vtarg_replay[self.replay_index] = vtarg_batch
+        self.exp_replay[self.replay_index] = obs_batch.copy()
+        self.vtarg_replay[self.replay_index] = vtarg_batch.copy()
         self.replay_index = (self.replay_index + 1) % self.n_pi
         return self.replay_index == 0
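update() now copies each batch into the selector's own buffers and advances a write index modulo n_pi, returning True (trigger a retune) exactly when the index wraps to zero, i.e. when all n_pi slots have been refilled since the last retune. The same pattern in isolation, with plain integers standing in for rollout batches:

```python
class RingBuffer:
    """Fixed number of slots; add() reports when a full cycle has been written."""
    def __init__(self, n_slots):
        self.slots = [None] * n_slots
        self.index = 0

    def add(self, item):
        self.slots[self.index] = item
        self.index = (self.index + 1) % len(self.slots)
        return self.index == 0   # True once per full pass over the buffer

buf = RingBuffer(4)
print([buf.add(i) for i in range(8)])  # [False, False, False, True, False, False, False, True]
```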
@@ -141,7 +144,7 @@ class RetuneSelector:
         self.replay_index = 0
 
-    def make_minibatches_with_rollouts(self, exp_replay, vtarg_replay, presleep_pi, num_rollouts=4):
+    def make_minibatches(self, presleep_pi, num_rollouts=4):
         if not self.flat_buffer:
             env_segs = list(itertools.product(range(self.n_pi), range(self.nenvs)))
             np.random.shuffle(env_segs)
@@ -149,9 +152,9 @@ class RetuneSelector:
             for idx in range(0, len(env_segs), num_rollouts):
                 esinds = env_segs[idx:idx+num_rollouts]
                 mbatch = [flatten01(arr[esinds[:,0], : , esinds[:,1]])
-                          for arr in (exp_replay, vtarg_replay, presleep_pi)]
+                          for arr in (self.exp_replay, self.vtarg_replay, presleep_pi)]
         else:
-            nsteps = vtarg_replay.shape[1]
+            nsteps = self.vtarg_replay.shape[1]
             buffsize = self.n_pi * nsteps * self.nenvs
             inds = np.arange(buffsize)
             np.random.shuffle(inds)
@@ -160,7 +163,7 @@ class RetuneSelector:
                 end = start+batchsize
                 mbinds = inds[start:end]
                 mbatch = [flatten012(arr)[mbinds]
-                          for arr in (exp_replay, vtarg_replay, presleep_pi)]
+                          for arr in (self.exp_replay, self.vtarg_replay, presleep_pi)]
                 yield mbatch
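The renamed make_minibatches reads from the selector-owned buffers in two modes: whole rollout segments (shuffle the (n_pi, nenvs) grid, keep each segment intact, merge the first two axes) or a fully flattened buffer (shuffle individual timesteps across all three leading axes). flatten01 and flatten012 exist in the repo but their bodies are not shown in this diff; the sketch below assumes they are plain axis-merging reshapes:

```python
import itertools
import numpy as np

def flatten01(arr):
    """Assumed helper: merge the first two axes, (a, b, ...) -> (a*b, ...)."""
    return arr.reshape(-1, *arr.shape[2:])

def flatten012(arr):
    """Assumed helper: merge the first three axes, (a, b, c, ...) -> (a*b*c, ...)."""
    return arr.reshape(-1, *arr.shape[3:])

n_pi, nsteps, nenvs = 4, 16, 8
vtarg_replay = np.random.randn(n_pi, nsteps, nenvs).astype(np.float32)

# Rollout-segment mode: pick (policy-iteration, env) pairs and keep whole segments.
env_segs = np.array(list(itertools.product(range(n_pi), range(nenvs))))
np.random.shuffle(env_segs)
esinds = env_segs[:4]
segment_batch = flatten01(vtarg_replay[esinds[:, 0], :, esinds[:, 1]])
print(segment_batch.shape)  # (4 * nsteps,) -> (64,)

# Flat mode: shuffle individual timesteps across the whole buffer.
inds = np.arange(n_pi * nsteps * nenvs)
np.random.shuffle(inds)
flat_batch = flatten012(vtarg_replay)[inds[:64]]
print(flat_batch.shape)  # (64,)
```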
@@ -7,7 +7,6 @@ procgen-ppo:
         timesteps_total: 8000000
         time_total_s: 7200
 
     # === Settings for Checkpoints ===
     checkpoint_freq: 25
     checkpoint_at_end: True
@@ -16,7 +15,7 @@ procgen-ppo:
     config:
         # === Settings for the Procgen Environment ===
         env_config:
-            env_name: coinrun
+            env_name: miner
             num_levels: 0
             start_level: 0
             paint_vel_info: False
@@ -57,6 +56,7 @@ procgen-ppo:
         reset_returns: False
         flattened_buffer: True
         augment_randint_num: 6 ## Hacky name fix later
+        aux_lr: 5.0e-4
         adaptive_gamma: False
         final_lr: 5.0e-5
@@ -88,6 +88,7 @@ procgen-ppo:
            nlatents: 512
            init_normed: True
            use_layernorm: False
+           diff_framestack: True
        num_workers: 7
        num_envs_per_worker: 16
@@ -75,9 +75,13 @@ class ImpalaCNN(TorchModelV2, nn.Module):
         depths = model_config['custom_options'].get('depths') or [16, 32, 32]
         nlatents = model_config['custom_options'].get('nlatents') or 256
         init_normed = model_config['custom_options'].get('init_normed') or False
-        self.use_layernorm = model_config['custom_options'].get('use_layernorm') or True
+        self.use_layernorm = model_config['custom_options'].get('use_layernorm') or False
+        self.diff_framestack = model_config['custom_options'].get('diff_framestack') or False
 
         h, w, c = obs_space.shape
+        if self.diff_framestack:
+            assert c == 6, "diff_framestack is only for frame_stack = 2"
+            c = 9
         shape = (c, h, w)
 
         conv_seqs = []
@@ -112,6 +116,8 @@ class ImpalaCNN(TorchModelV2, nn.Module):
     @override(TorchModelV2)
     def forward(self, input_dict, state, seq_lens):
         x = input_dict["obs"].float()
+        if self.diff_framestack:
+            x = torch.cat([x, x[...,:-3] - x[...,-3:]], dim=3) # only works for framestack 2 for now
         x = x / 255.0  # scale to 0-1
         x = x.permute(0, 3, 1, 2)  # NHWC => NCHW
         x = self.conv_seqs[0](x)
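With diff_framestack enabled the trunk sees 9 input channels: the 6 stacked-frame channels plus the 3-channel difference between the two frames, concatenated while the tensor is still NHWC. A small shape check (batch and image sizes are arbitrary):

```python
import torch

# Fake observation batch: NHWC, frame_stack = 2 -> 6 channels.
x = torch.randint(0, 256, (5, 64, 64, 6), dtype=torch.uint8).float()

# Append frame1 - frame2 as 3 extra channels (dim=3 is the channel axis in NHWC).
x = torch.cat([x, x[..., :-3] - x[..., -3:]], dim=3)
x = x / 255.0              # scale to 0-1
x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW for the conv stack

print(x.shape)  # torch.Size([5, 9, 64, 64])
```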