Commit aab9e125 authored by Dipam Chakraborty

save optimizer state

parent a66420c7
@@ -137,7 +137,7 @@ class CustomTorchPolicy(TorchPolicy):
next_obs = unroll(samples['new_obs'], ts)[-1]
last_values, _ = self.model.vf_pi(next_obs, ret_numpy=True, no_grad=True, to_torch=True)
values = np.empty((nbatch,), dtype=np.float32)
for start in range(0, nbatch, nbatch_train): # Causes OOM up if trying to do all at once (TODO: Try bigger than nbatch_train)
for start in range(0, nbatch, nbatch_train): # Causes OOM up if trying to do all at once
end = start + nbatch_train
values[start:end], _ = self.model.vf_pi(samples['obs'][start:end], ret_numpy=True, no_grad=True, to_torch=True)
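The loop above evaluates the value head in fixed-size chunks so the no-grad forward pass never has to hold the whole rollout in GPU memory at once. A minimal, self-contained sketch of that chunked-inference pattern, assuming a generic PyTorch value network (`ValueNet` and `chunked_values` are illustrative names, not from this repo):

```python
# Sketch: evaluate a value function over a large batch in fixed-size chunks
# to keep peak memory bounded. ValueNet stands in for self.model.vf_pi here.
import numpy as np
import torch
import torch.nn as nn

class ValueNet(nn.Module):
    def __init__(self, obs_dim):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, 1))

    def forward(self, obs):
        return self.body(obs).squeeze(-1)

def chunked_values(model, obs, chunk_size, device="cpu"):
    """Run inference chunk by chunk and return a single numpy array."""
    nbatch = obs.shape[0]
    values = np.empty((nbatch,), dtype=np.float32)
    model.eval()
    with torch.no_grad():                      # no autograd graph -> far less memory
        for start in range(0, nbatch, chunk_size):
            end = start + chunk_size
            batch = torch.as_tensor(obs[start:end], dtype=torch.float32, device=device)
            values[start:end] = model(batch).cpu().numpy()
    return values

obs = np.random.randn(2048, 8).astype(np.float32)
vals = chunked_values(ValueNet(8), obs, chunk_size=256)
print(vals.shape)  # (2048,)
```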
@@ -190,35 +190,6 @@ class CustomTorchPolicy(TorchPolicy):
self._batch_train(apply_grad, self.accumulate_train_batches,
lrnow, cliprange, vfcliprange, max_grad_norm, ent_coef, vf_coef, *slices)
# actions = samples['actions']
# old_memdata = nbatch, self.actual_batch_size_new, nbatch_train, self.accumulate_train_batches
# old_data = obs, mb_returns, actions, mb_values, neglogpacs
# old_memdata, new_data = self.smart_frameskip(ts, old_memdata, old_data)
# nbatch_new, actual_batch_size_new, nbatch_train_new, num_acc_new = old_memdata
# obs_new, returns_new, actions_new, values_new, neglogpacs_new = new_data
# ## Train multiple epochs
# optim_count = 0
# inds = np.arange(nbatch_new)
# for _ in range(noptepochs):
# np.random.shuffle(inds)
# normalized_advs = returns_new - values_new
# # Can do this because actual_batch_size is a multiple of mem_limited_batch_size
# for start in range(0, nbatch_new, actual_batch_size_new):
# end = start + actual_batch_size_new
# mbinds = inds[start:end]
# advs_batch = normalized_advs[mbinds].copy()
# normalized_advs[mbinds] = (advs_batch - np.mean(advs_batch)) / (np.std(advs_batch) + 1e-8)
# for start in range(0, nbatch_new, nbatch_train_new):
# end = start + nbatch_train_new
# mbinds = inds[start:end]
# slices = (self.to_tensor(arr[mbinds]) for arr in (obs_new, returns_new, actions_new, values_new,
# neglogpacs_new, normalized_advs))
# optim_count += 1
# apply_grad = (optim_count % num_acc_new) == 0
# self._batch_train(apply_grad, num_acc_new, lrnow, cliprange, vfcliprange, max_grad_norm, ent_coef, vf_coef, *slices)
self.update_gamma(samples)
self.update_lr()
self.update_ent_coef()
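`_batch_train` above only applies gradients every `accumulate_train_batches` minibatches (the `apply_grad` flag), which is plain gradient accumulation. A minimal sketch of that pattern outside RLlib, with a toy model and loss (all names below are illustrative, not from this commit):

```python
import torch
import torch.nn as nn

def train_with_accumulation(model, data_batches, num_accumulate, lr=2.5e-4, max_grad_norm=0.5):
    """Accumulate gradients over num_accumulate minibatches before each optimizer step."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    optimizer.zero_grad()
    for i, (x, y) in enumerate(data_batches, start=1):
        loss = nn.functional.mse_loss(model(x), y)
        (loss / num_accumulate).backward()        # scale so summed grads match one big batch
        apply_grad = (i % num_accumulate) == 0    # same idea as the apply_grad flag above
        if apply_grad:
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()

# Toy usage: 8 minibatches, step the optimizer every 2.
model = nn.Linear(4, 1)
batches = [(torch.randn(16, 4), torch.randn(16, 1)) for _ in range(8)]
train_with_accumulation(model, batches, num_accumulate=2)
```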
@@ -226,9 +197,6 @@ class CustomTorchPolicy(TorchPolicy):
self.update_batch_time()
return {}
# def smart_frameskip(ts, old_memvals, old_data):
def update_batch_time(self):
self.time_elapsed += time.time() - self.batch_end_time
self.batch_end_time = time.time()
@@ -346,8 +314,9 @@ class CustomTorchPolicy(TorchPolicy):
final_val=self.config['final_lr'],
current_steps=self.timesteps_total,
total_steps=self.target_timesteps)
elif self.config['lr_schedule'] == 'exponential':
self.lr = 0.997 * self.lr
self.lr = 0.997 * self.lr
def update_ent_coef(self):
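`update_lr` now chooses between a linear anneal toward `final_lr` and a per-update exponential decay by 0.997. A sketch of both schedules; the repo's `linear_schedule` helper is not shown in this diff, so the explicit `initial_val` argument below is an assumption, not the project's API:

```python
def linear_schedule(initial_val, final_val, current_steps, total_steps):
    """Linearly anneal from initial_val to final_val over total_steps."""
    frac = min(max(current_steps / total_steps, 0.0), 1.0)
    return initial_val + frac * (final_val - initial_val)

def exponential_schedule(current_val, decay=0.997):
    """Multiply the current value by a fixed decay factor on every update."""
    return decay * current_val

# Example: starting lr 5.0e-4, target 3.0e-4 over 8M timesteps.
lr_linear = linear_schedule(5.0e-4, 3.0e-4, current_steps=4_000_000, total_steps=8_000_000)
print(lr_linear)  # 4e-04 at the halfway point
```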
@@ -407,6 +376,10 @@ class CustomTorchPolicy(TorchPolicy):
k: v.cpu().detach().numpy()
for k, v in self.model.state_dict().items()
}
weights["optimizer_state"] = {
k: v
for k, v in self.optimizer.state_dict().items()
}
weights["custom_state_vars"] = self.get_custom_state_vars()
return weights
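For reference, a PyTorch optimizer's `state_dict()` has just two top-level keys, `state` (per-parameter buffers such as Adam's `exp_avg`/`exp_avg_sq`) and `param_groups`, with tensors nested inside `state`. If those tensors need to be made device-independent before serialization, one option is a recursive copy to CPU; the helper below is an illustrative sketch, not part of this commit (which stores the values as returned):

```python
import torch
import torch.nn as nn

def state_dict_to_cpu(obj):
    """Recursively copy a (possibly nested) optimizer state_dict with all tensors on CPU."""
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu()
    if isinstance(obj, dict):
        return {k: state_dict_to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(state_dict_to_cpu(v) for v in obj)
    return obj  # ints, floats, None, etc. pass through unchanged

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
model(torch.randn(3, 4)).sum().backward()
optimizer.step()                      # populates exp_avg / exp_avg_sq buffers

cpu_state = state_dict_to_cpu(optimizer.state_dict())
print(sorted(cpu_state.keys()))       # ['param_groups', 'state']
```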
@@ -414,8 +387,13 @@ class CustomTorchPolicy(TorchPolicy):
@override(TorchPolicy)
def set_weights(self, weights):
self.set_model_weights(weights["current_weights"])
self.set_optimizer_state(weights["optimizer_state"])
self.set_custom_state_vars(weights["custom_state_vars"])
def set_optimizer_state(self, optimizer_state):
optimizer_state = convert_to_torch_tensor(optimizer_state, device=self.device)
self.optimizer.load_state_dict(optimizer_state)
def set_model_weights(self, model_weights):
model_weights = convert_to_torch_tensor(model_weights, device=self.device)
self.model.load_state_dict(model_weights)
\ No newline at end of file
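Restoring the optimizer together with the model is what makes resumed training behave like uninterrupted training: without it, Adam's step counts and moment estimates restart from zero. A minimal round-trip sketch in plain PyTorch, without the RLlib `convert_to_torch_tensor` machinery:

```python
import torch
import torch.nn as nn

def make_policy():
    model = nn.Linear(8, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    return model, optimizer

# Train a little so the optimizer has non-trivial state.
model, optimizer = make_policy()
for _ in range(10):
    loss = model(torch.randn(32, 8)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Checkpoint both model and optimizer, mirroring get_weights()/set_weights() above.
checkpoint = {
    "current_weights": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}

# Restore into a fresh policy: Adam's step counts and moment buffers carry over.
new_model, new_optimizer = make_policy()
new_model.load_state_dict(checkpoint["current_weights"])
new_optimizer.load_state_dict(checkpoint["optimizer_state"])
restored_step = new_optimizer.state_dict()["state"][0]["step"]
print(restored_step)  # 10 (an int or 0-d tensor depending on torch version), not 0
```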
@@ -135,4 +135,8 @@ class CustomCallbacks(DefaultCallbacks):
result['current_gamma'] = trainer_policy.gamma
result['best_reward'] = trainer_policy.best_reward
result['best_rew_tsteps'] = trainer_policy.best_rew_tsteps
result['rnorm_var'] = trainer_policy.rewnorm.ret_rms.var
result['rnorm_mean'] = trainer_policy.rewnorm.ret_rms.mean
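`rewnorm.ret_rms` tracks the running mean and variance of discounted returns used for reward standardization; logging its `mean`/`var` makes the normalizer's state visible in training results. A sketch of such a tracker using the standard parallel-variance update (mirroring the widely used `RunningMeanStd` pattern, not necessarily this repo's exact class):

```python
import numpy as np

class RunningMeanStd:
    """Track running mean/variance of a stream of values (parallel-variance update)."""
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = epsilon

    def update(self, x):
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count
        new_mean = self.mean + delta * batch_count / tot_count
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m2 = m_a + m_b + np.square(delta) * self.count * batch_count / tot_count
        self.mean, self.var, self.count = new_mean, m2 / tot_count, tot_count

# Normalize rewards by the running std of returns, as with standardize_rewards.
rms = RunningMeanStd()
returns = np.random.randn(1000) * 5.0 + 2.0
rms.update(returns)
normalized = returns / np.sqrt(rms.var + 1e-8)
print(float(rms.mean), float(rms.var))  # roughly 2.0 and 25.0
```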
@@ -53,8 +53,8 @@ procgen-ppo:
standardize_rewards: True
adaptive_gamma: False
final_lr: 2.0e-4
lr_schedule: 'exponential'
final_lr: 3.0e-4
lr_schedule: 'linear'
final_entropy_coeff: 0.002
entropy_schedule: False
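The config switches from the exponential to the linear schedule and raises `final_lr` to 3.0e-4. Rough arithmetic on what that changes, assuming a base lr of 5.0e-4 (the base lr is set elsewhere in the config and is not shown in this hunk):

```python
import math

# Assumed starting lr of 5.0e-4 (not shown in this hunk).
base_lr, final_lr = 5.0e-4, 3.0e-4

# Under the old exponential schedule, lr_n = base_lr * 0.997**n, so reaching
# final_lr takes about log(final/base)/log(0.997) optimizer updates.
n_updates = math.log(final_lr / base_lr) / math.log(0.997)
print(round(n_updates))  # ~170 updates, after which lr keeps shrinking toward 0

# The new linear schedule instead anneals lr with training timesteps and stops
# exactly at final_lr when target_timesteps is reached (see update_lr above).
```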