Commit 53941a9d authored by Dipam Chakraborty

same lr everywhere ppg

parent 5c899dcc
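This commit moves the three optimizer learning rates into the config ("lr", "aux_lr", "value_lr") and adds a "same_lr_everywhere" switch that mirrors the scheduled PPO learning rate onto the value and aux optimizers inside update_lr(). A minimal sketch of that pattern, outside the diff (function and variable names here are illustrative, not from the repository):

import torch

def sync_lrs(lr, ppo_opt, value_opt, aux_opt, same_lr_everywhere):
    # The PPO optimizer always follows the schedule.
    for g in ppo_opt.param_groups:
        g['lr'] = lr
    # Optionally mirror the same lr onto the other two optimizers,
    # instead of keeping their own fixed value_lr / aux_lr.
    if same_lr_everywhere:
        for opt in (value_opt, aux_opt):
            for g in opt.param_groups:
                g['lr'] = lr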
@@ -50,9 +50,9 @@ class CustomTorchPolicy(TorchPolicy):
network_params = set(self.model.parameters())
aux_optim_params = list(network_params - value_params)
ppo_optim_params = list(network_params - aux_params - value_params)
self.optimizer = torch.optim.Adam(ppo_optim_params, lr=5e-4)
self.aux_optimizer = torch.optim.Adam(aux_optim_params, lr=5e-4)
self.value_optimizer = torch.optim.Adam(value_params, lr=1e-3)
self.optimizer = torch.optim.Adam(ppo_optim_params, lr=self.config['lr'])
self.aux_optimizer = torch.optim.Adam(aux_optim_params, lr=self.config['aux_lr'])
self.value_optimizer = torch.optim.Adam(value_params, lr=self.config['value_lr'])
self.max_reward = self.config['env_config']['return_max']
self.rewnorm = RewardNormalizer(cliprew=self.max_reward) ## TODO: Might need to go to custom state
self.reward_deque = deque(maxlen=100)
@@ -96,6 +96,8 @@ class CustomTorchPolicy(TorchPolicy):
self.make_distr = dist_build(action_space)
self.retunes_completed = 0
self.update_lr()
def to_tensor(self, arr):
return torch.from_numpy(arr).to(self.device)
@@ -169,7 +171,6 @@ class CustomTorchPolicy(TorchPolicy):
## Data from config
cliprange, vfcliprange = self.config['clip_param'], self.config['vf_clip_param']
lrnow = self.lr
max_grad_norm = self.config['grad_clip']
ent_coef, vf_coef = self.ent_coef, self.config['vf_loss_coeff']
@@ -193,7 +194,7 @@ class CustomTorchPolicy(TorchPolicy):
optim_count += 1
apply_grad = (optim_count % self.accumulate_train_batches) == 0
self._batch_train(apply_grad, self.accumulate_train_batches,
lrnow, cliprange, vfcliprange, max_grad_norm, ent_coef, vf_coef, *slices)
cliprange, vfcliprange, max_grad_norm, ent_coef, vf_coef, *slices)
## Distill with aux head
should_retune = self.retune_selector.update(unroll(obs, ts), mb_returns)
@@ -212,12 +213,10 @@ class CustomTorchPolicy(TorchPolicy):
self.batch_end_time = time.time()
def _batch_train(self, apply_grad, num_accumulate,
lr, cliprange, vfcliprange, max_grad_norm,
cliprange, vfcliprange, max_grad_norm,
ent_coef, vf_coef,
obs, returns, actions, values, logp_actions_old, advs):
for g in self.optimizer.param_groups:
g['lr'] = lr
vpred, pi_logits = self.model.vf_pi(obs, ret_numpy=False, no_grad=False, to_torch=False)
pd = self.make_distr(pi_logits)
logp_actions = pd.log_prob(actions[...,None]).squeeze(1)
@@ -245,8 +244,6 @@ class CustomTorchPolicy(TorchPolicy):
def aux_train(self):
for g in self.aux_optimizer.param_groups:
g['lr'] = self.config['aux_lr']
nbatch_train = self.mem_limited_batch_size
retune_epochs = self.config['retune_epochs']
replay_shape = self.retune_selector.vtarg_replay.shape
@@ -324,6 +321,14 @@ class CustomTorchPolicy(TorchPolicy):
elif self.config['lr_schedule'] == 'exponential':
self.lr = 0.997 * self.lr
for g in self.optimizer.param_groups:
g['lr'] = self.lr
if self.config['same_lr_everywhere']:
for g in self.value_optimizer.param_groups:
g['lr'] = self.lr
for g in self.aux_optimizer.param_groups:
g['lr'] = self.lr
def update_ent_coef(self):
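As a quick worked check of the 'exponential' branch above (an illustrative calculation, not code from the repository): multiplying by 0.997 on each call halves the learning rate roughly every 231 updates, since 0.997**231 is approximately 0.5.

lr = 5.0e-4                 # e.g. the configured starting lr
for _ in range(231):
    lr *= 0.997             # one decay step per update_lr() call
print(round(lr, 6))         # ~2.5e-4, i.e. about half the starting value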
@@ -383,32 +388,22 @@ class CustomTorchPolicy(TorchPolicy):
k: v.cpu().detach().numpy()
for k, v in self.model.state_dict().items()
}
# weights["optimizer_state"] = {
# k: v
# for k, v in self.optimizer.state_dict().items()
# }
# weights["aux_optimizer_state"] = {
# k: v
# for k, v in self.aux_optimizer.state_dict().items()
# }
# weights["custom_state_vars"] = self.get_custom_state_vars()
return weights
@override(TorchPolicy)
def set_weights(self, weights):
self.set_model_weights(weights["current_weights"])
# self.set_optimizer_state(weights["optimizer_state"])
# self.set_aux_optimizer_state(weights["aux_optimizer_state"])
# self.set_custom_state_vars(weights["custom_state_vars"])
def set_aux_optimizer_state(self, aux_optimizer_state):
def set_optimizer_state(self, optimizer_state, aux_optimizer_state, value_optimizer_state):
optimizer_state = convert_to_torch_tensor(optimizer_state, device=self.device)
self.optimizer.load_state_dict(optimizer_state)
aux_optimizer_state = convert_to_torch_tensor(aux_optimizer_state, device=self.device)
self.aux_optimizer.load_state_dict(aux_optimizer_state)
def set_optimizer_state(self, optimizer_state):
optimizer_state = convert_to_torch_tensor(optimizer_state, device=self.device)
self.optimizer.load_state_dict(optimizer_state)
value_optimizer_state = convert_to_torch_tensor(value_optimizer_state, device=self.device)
self.value_optimizer.load_state_dict(value_optimizer_state)
def set_model_weights(self, model_weights):
model_weights = convert_to_torch_tensor(model_weights, device=self.device)
......
@@ -206,9 +206,7 @@ def build_trainer(name,
Trainer.__setstate__(self, state)
policy = Trainer.get_policy(self)
self.state = state["trainer_state"].copy()
policy.set_optimizer_state(state["optimizer_state"])
policy.set_aux_optimizer_state(state["aux_optimizer_state"])
policy.set_value_optimizer_state(state["value_optimizer_state"])
policy.set_optimizer_state(state["optimizer_state"], state["aux_optimizer_state"], state["value_optimizer_state"])
policy.set_custom_state_vars(state["custom_state_vars"])
if self.train_exec_impl:
......
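The restore path above expects three optimizer state dicts in the checkpoint. A sketch of the saving side that would produce them (hypothetical helper; the actual __getstate__ in this repo may collect them differently):

def get_optimizer_state(policy):
    # Counterpart to policy.set_optimizer_state(...): pull the state dict
    # of each of the three optimizers so the trainer can checkpoint them.
    return {
        "optimizer_state": policy.optimizer.state_dict(),
        "aux_optimizer_state": policy.aux_optimizer.state_dict(),
        "value_optimizer_state": policy.value_optimizer.state_dict(),
    }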
@@ -92,6 +92,8 @@ DEFAULT_CONFIG = with_common_config({
"flattened_buffer": False,
"augment_randint_num": 6,
"aux_lr": 5e-4,
"value_lr": 1e-3,
"same_lr_everywhere": False,
})
# __sphinx_doc_end__
# yapf: enable
......
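With the defaults above, the new keys can be overridden per experiment like any other entry in the trainer config. A small usage sketch in dict form (the YAML below does the same thing):

config = DEFAULT_CONFIG.copy()
config.update({
    "lr": 5.0e-4,
    "aux_lr": 5.0e-4,
    "value_lr": 1.0e-3,
    "same_lr_everywhere": True,  # mirror the scheduled PPO lr onto the value/aux optimizers
})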
@@ -45,8 +45,8 @@ procgen-ppo:
no_done_at_end: False
# Custom switches
skips: 0
n_pi: 18
skips: 9
n_pi: 9
num_retunes: 100
retune_epochs: 6
standardize_rewards: True
@@ -57,6 +57,8 @@ procgen-ppo:
flattened_buffer: True
augment_randint_num: 6 ## Hacky name fix later
aux_lr: 5.0e-4
value_lr: 1.0e-3
same_lr_everywhere: True
adaptive_gamma: False
final_lr: 5.0e-5
@@ -82,10 +84,10 @@ procgen-ppo:
model:
custom_model: impala_torch_ppg
custom_options:
depths: [16, 32, 32]
nlatents: 256
# depths: [32, 64, 64]
# nlatents: 512
# depths: [16, 32, 32]
# nlatents: 256
depths: [32, 64, 64]
nlatents: 512
init_normed: True
use_layernorm: False
diff_framestack: True
......