Commit c1c85730 authored by Dipam Chakraborty

mixed precision aux

parent 850e912b
......@@ -10,6 +10,7 @@ import time
torch, nn = try_import_torch()
import torch.distributions as td
from torch.cuda.amp import autocast, GradScaler
class CustomTorchPolicy(TorchPolicy):
"""Example of a random policy
......@@ -95,6 +96,7 @@ class CustomTorchPolicy(TorchPolicy):
self.last_dones = np.zeros((nw * self.config['num_envs_per_worker'],))
self.make_distr = dist_build(action_space)
self.retunes_completed = 0
self.amp_scaler = GradScaler()
self.update_lr()
......@@ -199,7 +201,10 @@ class CustomTorchPolicy(TorchPolicy):
## Distill with aux head
should_retune = self.retune_selector.update(unroll(obs, ts), mb_returns)
if should_retune:
import time
tnow = time.perf_counter()
self.aux_train()
print("Aux Train %fs" % (time.perf_counter()-tnow))
self.update_gamma(samples)
self.update_lr()
......@@ -273,25 +278,41 @@ class CustomTorchPolicy(TorchPolicy):
else:
obs_in = self.to_tensor(obs)
if not self.config['aux_phase_mixed_precision']:
loss, vf_loss = self._aux_calc_loss(obs_in, target_vf, target_pi)
loss.backward()
vf_loss.backward()
self.aux_optimizer.step()
self.value_optimizer.step()
else:
with autocast():
loss, vf_loss = self._aux_calc_loss(obs_in, target_vf, target_pi)
self.amp_scaler.scale(loss).backward(retain_graph=True)
self.amp_scaler.scale(vf_loss).backward()
self.amp_scaler.step(self.aux_optimizer)
self.amp_scaler.step(self.value_optimizer)
self.amp_scaler.update()
self.aux_optimizer.zero_grad()
self.value_optimizer.zero_grad()
def _aux_calc_loss(self, obs_in, target_vf, target_pi):
vpred, pi_logits = self.model.vf_pi(obs_in, ret_numpy=False, no_grad=False, to_torch=False)
aux_vpred = self.model.aux_value_function()
aux_loss = .5 * torch.mean(torch.pow(aux_vpred - target_vf, 2))
target_pd = self.make_distr(target_pi)
pd = self.make_distr(pi_logits)
pi_loss = td.kl_divergence(target_pd, pd).mean()
loss = pi_loss + aux_loss
loss.backward()
self.aux_optimizer.step()
self.aux_optimizer.zero_grad()
vf_loss = .5 * torch.mean(torch.pow(vpred - target_vf, 2))
vf_loss.backward()
self.value_optimizer.step()
self.value_optimizer.zero_grad()
return loss, vf_loss
def best_reward_model_select(self, samples):
self.timesteps_total += len(samples['dones'])
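
The mixed-precision branch added above follows the standard torch.cuda.amp recipe: run the forward pass and loss computation under autocast, scale each loss before backward, step both optimizers through the GradScaler, then call update() once per iteration. A minimal, self-contained sketch of that pattern (placeholder modules, losses, and sizes, not the repository's classes; requires a CUDA device):

    import torch
    from torch.cuda.amp import autocast, GradScaler

    trunk = torch.nn.Linear(8, 16).cuda()       # shared features, like the policy trunk
    pi_head = torch.nn.Linear(16, 4).cuda()     # stands in for the policy / aux head
    vf_head = torch.nn.Linear(16, 1).cuda()     # stands in for the value head
    aux_optimizer = torch.optim.Adam(
        list(trunk.parameters()) + list(pi_head.parameters()), lr=5e-4)
    value_optimizer = torch.optim.Adam(vf_head.parameters(), lr=1e-3)
    amp_scaler = GradScaler()

    obs = torch.randn(32, 8, device="cuda")
    target_pi = torch.randn(32, 4, device="cuda")
    target_vf = torch.randn(32, 1, device="cuda")

    with autocast():                             # forward + losses in mixed precision
        features = trunk(obs)
        pi_loss = (pi_head(features) - target_pi).pow(2).mean()
        vf_loss = (vf_head(features) - target_vf).pow(2).mean()

    amp_scaler.scale(pi_loss).backward(retain_graph=True)  # graph reused by second backward
    amp_scaler.scale(vf_loss).backward()
    amp_scaler.step(aux_optimizer)               # unscales grads, skips step on inf/NaN
    amp_scaler.step(value_optimizer)
    amp_scaler.update()                          # adjust the loss scale once per iteration
    aux_optimizer.zero_grad()
    value_optimizer.zero_grad()

retain_graph=True on the first backward is needed because both losses share the trunk's graph; the two optimizers hold disjoint parameter sets, so each scaler.step() unscales its own gradients exactly once.
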
......@@ -389,13 +410,12 @@ class CustomTorchPolicy(TorchPolicy):
for k, v in self.model.state_dict().items()
}
return weights
@override(TorchPolicy)
def set_weights(self, weights):
self.set_model_weights(weights["current_weights"])
def set_optimizer_state(self, optimizer_state, aux_optimizer_state, value_optimizer_state):
def set_optimizer_state(self, optimizer_state, aux_optimizer_state, value_optimizer_state, amp_scaler_state):
optimizer_state = convert_to_torch_tensor(optimizer_state, device=self.device)
self.optimizer.load_state_dict(optimizer_state)
......@@ -405,6 +425,9 @@ class CustomTorchPolicy(TorchPolicy):
value_optimizer_state = convert_to_torch_tensor(value_optimizer_state, device=self.device)
self.value_optimizer.load_state_dict(value_optimizer_state)
amp_scaler_state = convert_to_torch_tensor(amp_scaler_state, device=self.device)
self.amp_scaler.load_state_dict(amp_scaler_state)
def set_model_weights(self, model_weights):
model_weights = convert_to_torch_tensor(model_weights, device=self.device)
self.model.load_state_dict(model_weights)
\ No newline at end of file
......@@ -196,6 +196,7 @@ def build_trainer(name,
state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
state["aux_optimizer_state"] = {k: v for k, v in policy.aux_optimizer.state_dict().items()}
state["value_optimizer_state"] = {k: v for k, v in policy.value_optimizer.state_dict().items()}
state["amp_scaler_state"] = {k: v for k, v in policy.amp_scaler.state_dict().items()}
if self.train_exec_impl:
state["train_exec_impl"] = (
......@@ -206,7 +207,8 @@ def build_trainer(name,
Trainer.__setstate__(self, state)
policy = Trainer.get_policy(self)
self.state = state["trainer_state"].copy()
policy.set_optimizer_state(state["optimizer_state"], state["aux_optimizer_state"], state["value_optimizer_state"])
policy.set_optimizer_state(state["optimizer_state"], state["aux_optimizer_state"],
state["value_optimizer_state"], state["amp_scaler_state"])
policy.set_custom_state_vars(state["custom_state_vars"])
if self.train_exec_impl:
......
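
With the scaler attached to the policy, its state (dynamic loss scale and growth tracker) is checkpointed and restored alongside the optimizer states, so a resumed run keeps the same loss scale. The underlying PyTorch calls, in a minimal sketch (checkpoint path is illustrative):

    import torch
    from torch.cuda.amp import GradScaler

    amp_scaler = GradScaler()
    # ... training steps adjust the scaler's dynamic loss scale ...

    state = {"amp_scaler_state": amp_scaler.state_dict()}   # plain dict of Python scalars
    torch.save(state, "checkpoint.pt")                      # illustrative path

    restored_scaler = GradScaler()
    restored_scaler.load_state_dict(torch.load("checkpoint.pt")["amp_scaler_state"])
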
......@@ -94,6 +94,7 @@ DEFAULT_CONFIG = with_common_config({
"aux_lr": 5e-4,
"value_lr": 1e-3,
"same_lr_everywhere": False,
"aux_phase_mixed_precision": False,
})
# __sphinx_doc_end__
# yapf: enable
......
......@@ -8,7 +8,7 @@ procgen-ppo:
time_total_s: 7200
# === Settings for Checkpoints ===
checkpoint_freq: 25
checkpoint_freq: 1
checkpoint_at_end: True
keep_checkpoints_num: 5
......@@ -46,19 +46,20 @@ procgen-ppo:
# Custom switches
skips: 0
n_pi: 18
n_pi: 16
num_retunes: 100
retune_epochs: 3
retune_epochs: 6
standardize_rewards: True
aux_mbsize: 4
augment_buffer: True
scale_reward: 1.0
reset_returns: False
flattened_buffer: True
augment_randint_num: 6 ## Hacky name fix later
augment_randint_num: 3 ## Hacky name fix later
aux_lr: 5.0e-4
value_lr: 1.0e-3
same_lr_everywhere: False
aux_phase_mixed_precision: True
adaptive_gamma: False
final_lr: 5.0e-5
......@@ -83,7 +84,7 @@ procgen-ppo:
# === Settings for Model ===
model:
custom_model: impala_torch_ppg
custom_options:
custom_model_config:
# depths: [16, 32, 32]
# nlatents: 256
depths: [32, 64, 64]
......@@ -93,7 +94,7 @@ procgen-ppo:
diff_framestack: True
num_workers: 7
num_envs_per_worker: 16
num_envs_per_worker: 9
rollout_fragment_length: 256
......
......@@ -70,13 +70,12 @@ class ImpalaCNN(TorchModelV2, nn.Module):
TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
model_config, name)
nn.Module.__init__(self)
self.device = device
depths = model_config['custom_options'].get('depths')
nlatents = model_config['custom_options'].get('nlatents')
init_normed = model_config['custom_options'].get('init_normed')
self.use_layernorm = model_config['custom_options'].get('use_layernorm')
self.diff_framestack = model_config['custom_options'].get('diff_framestack')
depths = model_config['custom_model_config'].get('depths')
nlatents = model_config['custom_model_config'].get('nlatents')
init_normed = model_config['custom_model_config'].get('init_normed')
self.use_layernorm = model_config['custom_model_config'].get('use_layernorm')
self.diff_framestack = model_config['custom_model_config'].get('diff_framestack')
h, w, c = obs_space.shape
if self.diff_framestack:
......
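
The last hunk tracks RLlib's rename of the per-model options key from custom_options to custom_model_config, both in the YAML and where ImpalaCNN reads it. A minimal sketch of how a custom TorchModelV2 consumes that key (illustrative stand-in model, not the repository's ImpalaCNN; assumes an RLlib version that uses custom_model_config):

    import numpy as np
    import torch
    import torch.nn as nn
    from ray.rllib.models import ModelCatalog
    from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

    class TinyTorchModel(TorchModelV2, nn.Module):
        """Illustrative stand-in showing the renamed config key."""
        def __init__(self, obs_space, action_space, num_outputs, model_config, name):
            TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                                  model_config, name)
            nn.Module.__init__(self)
            opts = model_config["custom_model_config"]  # was model_config["custom_options"]
            hidden = opts.get("nlatents", 256)
            self.body = nn.Linear(int(np.prod(obs_space.shape)), hidden)
            self.logits = nn.Linear(hidden, num_outputs)
            self.vf = nn.Linear(hidden, 1)
            self._features = None

        def forward(self, input_dict, state, seq_lens):
            x = input_dict["obs_flat"].float()
            self._features = torch.relu(self.body(x))
            return self.logits(self._features), state

        def value_function(self):
            return self.vf(self._features).squeeze(1)

    ModelCatalog.register_custom_model("tiny_torch_model", TinyTorchModel)

    # Trainer config side: the key is "custom_model_config", matching the YAML change above.
    model_config = {
        "custom_model": "tiny_torch_model",
        "custom_model_config": {"nlatents": 256},
    }
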