Commit 2f2c6b78 authored by Dipam Chakraborty

Value prediction while computing actions

parent fcd5a7b5
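In plain terms, this commit has the policy return its value prediction as an extra action output during rollouts, so training can read `samples['values']` instead of re-running the value head over the whole batch. A minimal, self-contained sketch of that caching pattern (the toy model and helper names below are illustrative, not this repository's code):

```python
import numpy as np
import torch
import torch.nn as nn

class TinyActorCritic(nn.Module):
    """Toy actor-critic used only to illustrate the value-caching pattern."""
    def __init__(self, obs_dim, n_actions):
        super().__init__()
        self.pi = nn.Linear(obs_dim, n_actions)
        self.vf = nn.Linear(obs_dim, 1)
        self._value = None  # filled in by the forward pass that picks actions

    def forward(self, obs):
        self._value = self.vf(obs).squeeze(-1)  # cache the value prediction
        return self.pi(obs)

def compute_actions(model, obs):
    """Action computation also emits the cached values as an extra output,
    analogous to extra_action_out returning {'values': model._value.tolist()}."""
    with torch.no_grad():
        logits = model(obs)
        actions = torch.distributions.Categorical(logits=logits).sample()
    return actions.numpy(), {'values': model._value.numpy()}

def learn_on_batch(samples):
    """Training consumes the rollout-time values instead of recomputing them."""
    values = np.asarray(samples['values'], dtype=np.float32)
    return values  # ready to feed into GAE / return computation

model = TinyActorCritic(obs_dim=4, n_actions=3)
actions, extra = compute_actions(model, torch.randn(8, 4))
values = learn_on_batch({'values': extra['values']})
```

The payoff is one fewer full forward pass over the training batch, which is also what the removed minibatched value-prediction loop below was working around.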
@@ -43,6 +43,7 @@ class CustomTorchPolicy(TorchPolicy):
             loss=None,
             action_distribution_class=dist_class,
         )
+        self.framework = "torch"
         aux_params = set(self.model.aux_vf.parameters())
         value_params = set(self.model.value_fc.parameters())
@@ -92,11 +93,16 @@ class CustomTorchPolicy(TorchPolicy):
         self.ent_coef = config['entropy_coeff']
 
         self.last_dones = np.zeros((nw * self.config['num_envs_per_worker'],))
-        self.make_distr = dist_build(action_space)
+        # self.make_distr = dist_build(action_space)
+        self.make_distr = dist_class
         self.retunes_completed = 0
 
     def to_tensor(self, arr):
         return torch.from_numpy(arr).to(self.device)
 
+    @override(TorchPolicy)
+    def extra_action_out(self, input_dict, state_batches, model, action_dist):
+        return {'values': model._value.tolist()}
+
     @override(TorchPolicy)
     def learn_on_batch(self, samples):
@@ -145,10 +151,7 @@ class CustomTorchPolicy(TorchPolicy):
         ## Value prediction
         next_obs = unroll(samples['new_obs'], ts)[-1]
         last_values, _ = self.model.vf_pi(next_obs, ret_numpy=True, no_grad=True, to_torch=True)
-        values = np.empty((nbatch,), dtype=np.float32)
-        for start in range(0, nbatch, nbatch_train):  # Causes OOM if trying to do it all at once
-            end = start + nbatch_train
-            values[start:end], _ = self.model.vf_pi(samples['obs'][start:end], ret_numpy=True, no_grad=True, to_torch=True)
+        values = samples['values']
 
         ## GAE
         mb_values = unroll(values, ts)
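For reference, `last_values` from the bootstrap observation and the per-timestep `values` typically enter a standard GAE recursion once unrolled to `[timesteps, num_envs]`. A generic sketch of that recursion in the OpenAI-baselines style this code follows (the `gamma`/`lam` defaults and argument names are assumptions, not values from this repository):

```python
import numpy as np

def gae(mb_rewards, mb_values, mb_dones, last_values, last_dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over arrays shaped [timesteps, num_envs]."""
    nsteps = mb_rewards.shape[0]
    mb_advs = np.zeros_like(mb_rewards)
    lastgaelam = 0.0
    for t in reversed(range(nsteps)):
        if t == nsteps - 1:
            nextnonterminal = 1.0 - last_dones      # bootstrap from the state after the rollout
            nextvalues = last_values
        else:
            nextnonterminal = 1.0 - mb_dones[t + 1]
            nextvalues = mb_values[t + 1]
        delta = mb_rewards[t] + gamma * nextvalues * nextnonterminal - mb_values[t]
        mb_advs[t] = lastgaelam = delta + gamma * lam * nextnonterminal * lastgaelam
    mb_returns = mb_advs + mb_values
    return mb_advs, mb_returns
```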
@@ -171,7 +174,7 @@ class CustomTorchPolicy(TorchPolicy):
         max_grad_norm = self.config['grad_clip']
         ent_coef, vf_coef = self.ent_coef, self.config['vf_loss_coeff']
 
-        neglogpacs = -samples['action_logp']  ## np.isclose seems to be True always, otherwise compute again if needed
+        logp_actions = samples['action_logp']  ## np.isclose seems to be True always, otherwise compute again if needed
         noptepochs = self.config['num_sgd_iter']
         actions = samples['actions']
         returns = roll(mb_returns)
@@ -187,7 +190,7 @@ class CustomTorchPolicy(TorchPolicy):
             for start in range(0, nbatch, nbatch_train):
                 end = start + nbatch_train
                 mbinds = inds[start:end]
-                slices = (self.to_tensor(arr[mbinds]) for arr in (obs, returns, actions, values, neglogpacs, normalized_advs))
+                slices = (self.to_tensor(arr[mbinds]) for arr in (obs, returns, actions, values, logp_actions, normalized_advs))
                 optim_count += 1
                 apply_grad = (optim_count % self.accumulate_train_batches) == 0
                 self._batch_train(apply_grad, self.accumulate_train_batches,
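The minibatch loop above only applies gradients every `accumulate_train_batches` calls (`apply_grad` gates the optimizer step inside `_batch_train`). A generic sketch of that accumulation pattern with a placeholder model and loss (not this repository's code):

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
accumulate = 2  # step the optimizer every 2 minibatches

for optim_count, batch in enumerate(torch.randn(8, 16, 4).unbind(0), start=1):
    loss = model(batch).pow(2).mean() / accumulate  # scale so accumulated grads match one big batch
    loss.backward()                                 # gradients add up across minibatches
    if optim_count % accumulate == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()
        optimizer.zero_grad()
```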
@@ -212,18 +215,18 @@ class CustomTorchPolicy(TorchPolicy):
     def _batch_train(self, apply_grad, num_accumulate,
                      lr, cliprange, vfcliprange, max_grad_norm,
                      ent_coef, vf_coef,
-                     obs, returns, actions, values, neglogpac_old, advs):
+                     obs, returns, actions, values, logp_actions_old, advs):
 
         for g in self.optimizer.param_groups:
             g['lr'] = lr
         vpred, pi_logits = self.model.vf_pi(obs, ret_numpy=False, no_grad=False, to_torch=False)
         pd = self.make_distr(pi_logits)
-        neglogpac = -pd.log_prob(actions[...,None]).squeeze(1)
+        logp_actions = pd.logp(actions[...,None]).squeeze(1)
         entropy = torch.mean(pd.entropy())
 
         vf_loss = .5 * torch.mean(torch.pow((vpred - returns), 2)) * vf_coef
-        ratio = torch.exp(neglogpac_old - neglogpac)
+        ratio = torch.exp(logp_actions - logp_actions_old)
         pg_losses1 = -advs * ratio
         pg_losses2 = -advs * torch.clamp(ratio, 1-cliprange, 1+cliprange)
         pg_loss = torch.mean(torch.max(pg_losses1, pg_losses2))
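Note that switching from negative log-probabilities to log-probabilities leaves the PPO probability ratio unchanged, since exp(logp_new - logp_old) = exp((-logp_old) - (-logp_new)); the clipped surrogate loss is therefore numerically identical. A quick illustrative check with arbitrary numbers:

```python
import torch

logp_old = torch.tensor([-1.20, -0.35, -2.10])
logp_new = torch.tensor([-1.05, -0.40, -1.90])

ratio_new_form = torch.exp(logp_new - logp_old)        # logp formulation (this commit)
ratio_old_form = torch.exp((-logp_old) - (-logp_new))  # neglogpac formulation (before)
assert torch.allclose(ratio_new_form, ratio_old_form)
```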
@@ -195,44 +195,7 @@ def build_trainer(name,
             state["custom_state_vars"] = policy.get_custom_state_vars()
             state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
             state["aux_optimizer_state"] = {k: v for k, v in policy.aux_optimizer.state_dict().items()}
-            ## Ugly hack to save replay buffer because organizers taking forever to give fix for spot instances
-            # save_success = False
-            # max_size = 3_700_000_000
-            # if policy.exp_replay.nbytes < max_size:
-            #     state["replay_buffer"] = policy.exp_replay
-            #     state["buffer_saved"] = 1
-            #     policy.save_success = 1
-            #     save_success = True
-            # elif policy.exp_replay.shape[-1] == 6:  # only for frame stack = 2
-            #     eq = np.all(policy.exp_replay[:,1:,...,:3] == policy.exp_replay[:,:-1,...,-3:], axis=(-3,-2,-1))
-            #     non_eq = np.where(1 - eq)
-            #     images_non_eq = policy.exp_replay[non_eq]
-            #     images_last = policy.exp_replay[:,-1,...,-3:]
-            #     images_first = policy.exp_replay[:,0,...,:3]
-            #     if policy.exp_replay[:,1:,...,:3].nbytes < max_size:
-            #         state["sliced_buffer"] = policy.exp_replay[:,1:,...,:3]
-            #         state["buffer_saved"] = 2
-            #         policy.save_success = 2
-            #         save_success = True
-            #     else:
-            #         comp = compress(policy.exp_replay[:,1:,...,:3].copy(), level=9)
-            #         if getsizeof(comp) < max_size:
-            #             state["compressed_buffer"] = comp
-            #             state["buffer_saved"] = 3
-            #             policy.save_success = 3
-            #             save_success = True
-            #     if save_success:
-            #         state["matched_frame_data"] = [non_eq, images_non_eq, images_last, images_first]
-            # if not save_success:
-            #     state["buffer_saved"] = -1
-            #     policy.save_success = -1
-            #     print("####################### BUFFER SAVE FAILED #########################")
-            # else:
-            #     state["vtarg_replay"] = policy.vtarg_replay
-            #     state["retune_selector"] = policy.retune_selector
+            state["value_optimizer_state"] = {k: v for k, v in policy.value_optimizer.state_dict().items()}
 
             if self.train_exec_impl:
                 state["train_exec_impl"] = (
@@ -245,29 +208,9 @@ def build_trainer(name,
             self.state = state["trainer_state"].copy()
             policy.set_optimizer_state(state["optimizer_state"])
             policy.set_aux_optimizer_state(state["aux_optimizer_state"])
+            policy.set_value_optimizer_state(state["value_optimizer_state"])
             policy.set_custom_state_vars(state["custom_state_vars"])
-            ## Ugly hack to save replay buffer because organizers taking forever to give fix for spot instances
-            # buffer_saved = state.get("buffer_saved", -1)
-            # policy.save_success = buffer_saved
-            # if buffer_saved == 1:
-            #     policy.exp_replay = state["replay_buffer"]
-            # elif buffer_saved > 1:
-            #     non_eq, images_non_eq, images_last, images_first = state["matched_frame_data"]
-            #     policy.exp_replay[non_eq] = images_non_eq
-            #     policy.exp_replay[:,-1,...,-3:] = images_last
-            #     policy.exp_replay[:,0,...,:3] = images_first
-            #     if buffer_saved == 2:
-            #         policy.exp_replay[:,1:,...,:3] = state["sliced_buffer"]
-            #     elif buffer_saved == 3:
-            #         ts = policy.exp_replay[:,1:,...,:3].shape
-            #         dt = policy.exp_replay.dtype
-            #         decomp = decompress(state["compressed_buffer"])
-            #         policy.exp_replay[:,1:,...,:3] = np.array(np.frombuffer(decomp, dtype=dt).reshape(ts))
-            # if buffer_saved > 0:
-            #     policy.vtarg_replay = state["vtarg_replay"]
-            #     policy.retune_selector = state["retune_selector"]
 
             if self.train_exec_impl:
                 self.train_exec_impl.shared_metrics.get().restore(
                     state["train_exec_impl"])
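The checkpoint change mirrors what is already done for the main and auxiliary optimizers: the value optimizer's `state_dict()` is stored on save and handed to a setter on restore. A minimal sketch of that round trip in plain PyTorch (assuming `set_value_optimizer_state` is a thin wrapper over `load_state_dict`):

```python
import torch

params = [torch.nn.Parameter(torch.randn(3))]
value_optimizer = torch.optim.Adam(params, lr=1e-3)

# Save: copy the optimizer state into the checkpoint dict, as the trainer does above.
state = {"value_optimizer_state": {k: v for k, v in value_optimizer.state_dict().items()}}

# Restore: a freshly built optimizer picks up step counts and moment estimates again.
restored = torch.optim.Adam(params, lr=1e-3)
restored.load_state_dict(state["value_optimizer_state"])
```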
@@ -85,6 +85,11 @@ class CustomTorchPolicy(TorchPolicy):
     def to_tensor(self, arr):
         return torch.from_numpy(arr).to(self.device)
 
+    @override(TorchPolicy)
+    def extra_action_out(self, input_dict, state_batches, model, action_dist):
+        return {'values': model._value.tolist()}
+
     @override(TorchPolicy)
     def learn_on_batch(self, samples):

@@ -136,10 +141,7 @@ class CustomTorchPolicy(TorchPolicy):
         ## Value prediction
         next_obs = unroll(samples['new_obs'], ts)[-1]
         last_values, _ = self.model.vf_pi(next_obs, ret_numpy=True, no_grad=True, to_torch=True)
-        values = np.empty((nbatch,), dtype=np.float32)
-        for start in range(0, nbatch, nbatch_train):  # Causes OOM if trying to do it all at once
-            end = start + nbatch_train
-            values[start:end], _ = self.model.vf_pi(samples['obs'][start:end], ret_numpy=True, no_grad=True, to_torch=True)
+        values = samples['values']
 
         ## GAE
         mb_values = unroll(values, ts)