Commit 41bef706 authored by Dipam Chakraborty

ppo amp

parent df3d29a1
@@ -9,6 +9,7 @@ from .utils import *
 import time
 torch, nn = try_import_torch()
+from torch.cuda.amp import autocast, GradScaler

 class CustomTorchPolicy(TorchPolicy):
     """Example of a random policy
@@ -81,6 +82,7 @@ class CustomTorchPolicy(TorchPolicy):
         self.last_dones = np.zeros((nw * self.config['num_envs_per_worker'],))
         self.save_success = 0
         self.retunes_completed = 0
+        self.amp_scaler = GradScaler()

     def to_tensor(self, arr):
         return torch.from_numpy(arr).to(self.device)
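
For reference, the `amp_scaler` added here follows PyTorch's standard `torch.cuda.amp` recipe: run the forward pass under `autocast`, scale the loss before `backward`, and let the scaler step and update the optimizer. Below is a minimal, self-contained sketch of that recipe, assuming a CUDA-enabled build; the tiny model, optimizer, and random data are stand-ins, not the policy's own objects.

```python
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

# Stand-in model and optimizer; CustomTorchPolicy holds its own.
model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()  # plays the role of self.amp_scaler above

def train_step(x, y):
    optimizer.zero_grad()
    with autocast():               # forward pass and loss run in mixed precision
        loss = nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()  # scale the loss so fp16 gradients don't underflow
    scaler.step(optimizer)         # unscales gradients; skips the step on inf/nan
    scaler.update()                # adjusts the scale factor for the next iteration
    return loss.item()

train_step(torch.randn(8, 16, device="cuda"), torch.randn(8, 4, device="cuda"))
```
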
@@ -282,6 +284,23 @@ class CustomTorchPolicy(TorchPolicy):
         with torch.no_grad():
             tpi_log_softmax = nn.functional.log_softmax(target_pi, dim=1)
             tpi_softmax = torch.exp(tpi_log_softmax)
+        if not self.config['aux_phase_mixed_precision']:
+            loss = self._retune_calc_loss(obs_aug, target_vf, tpi_softmax, tpi_log_softmax, retune_vf_loss_coeff)
+            loss.backward()
+            if apply_grad:
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+        else:
+            with autocast():
+                loss = self._retune_calc_loss(obs_aug, target_vf, tpi_softmax, tpi_log_softmax, retune_vf_loss_coeff)
+            self.amp_scaler.scale(loss).backward()
+            if apply_grad:
+                self.amp_scaler.step(self.optimizer)
+                self.amp_scaler.update()
+                self.optimizer.zero_grad()
+
+    def _retune_calc_loss(self, obs_aug, target_vf, tpi_softmax, tpi_log_softmax, retune_vf_loss_coeff):
         vpred, pi_logits = self.model.vf_pi(obs_aug, ret_numpy=False, no_grad=False, to_torch=False)
         pi_log_softmax = nn.functional.log_softmax(pi_logits, dim=1)
         pi_loss = torch.mean(torch.sum(tpi_softmax * (tpi_log_softmax - pi_log_softmax), dim=1))  # kl_div torch 1.3.1 has numerical issues
@@ -289,11 +308,7 @@ class CustomTorchPolicy(TorchPolicy):
         loss = retune_vf_loss_coeff * vf_loss + pi_loss
         loss = loss / self.accumulate_train_batches
-        loss.backward()
-        if apply_grad:
-            self.optimizer.step()
-            self.optimizer.zero_grad()
+        return loss

     def best_reward_model_select(self, samples):
         self.timesteps_total += self.nbatch
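
The refactor above pulls the loss math into `_retune_calc_loss` so the same code can run either in plain fp32 or under `autocast`, while keeping the existing gradient-accumulation behaviour: the loss is divided by `accumulate_train_batches` and the optimizer is only stepped when `apply_grad` is set. The sketch below mirrors that branch-plus-accumulation structure with illustrative stand-ins (a toy model, loss, and flag), not the repo's API, and assumes a CUDA device.

```python
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

model = nn.Linear(8, 1).cuda()
optimizer = torch.optim.Adam(model.parameters())
amp_scaler = GradScaler()
use_amp = True                      # stands in for config['aux_phase_mixed_precision']
accumulate_train_batches = 4        # stands in for self.accumulate_train_batches

def calc_loss(x, y):
    # Shared loss math, analogous to _retune_calc_loss above.
    return nn.functional.mse_loss(model(x), y) / accumulate_train_batches

for i in range(8):
    x = torch.randn(32, 8, device="cuda")
    y = torch.randn(32, 1, device="cuda")
    apply_grad = (i + 1) % accumulate_train_batches == 0
    if not use_amp:
        loss = calc_loss(x, y)
        loss.backward()             # gradients accumulate across micro-batches
        if apply_grad:
            optimizer.step()
            optimizer.zero_grad()
    else:
        with autocast():            # only the forward/loss runs in mixed precision
            loss = calc_loss(x, y)
        amp_scaler.scale(loss).backward()
        if apply_grad:
            amp_scaler.step(optimizer)
            amp_scaler.update()
            optimizer.zero_grad()
```
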
@@ -384,24 +399,20 @@ class CustomTorchPolicy(TorchPolicy):
             k: v.cpu().detach().numpy()
             for k, v in self.model.state_dict().items()
         }
-        # weights["optimizer_state"] = {
-        #     k: v
-        #     for k, v in self.optimizer.state_dict().items()
-        # }
-        # weights["custom_state_vars"] = self.get_custom_state_vars()
         return weights

     @override(TorchPolicy)
     def set_weights(self, weights):
         self.set_model_weights(weights["current_weights"])
-        # self.set_optimizer_state(weights["optimizer_state"])
-        # self.set_custom_state_vars(weights["custom_state_vars"])

-    def set_optimizer_state(self, optimizer_state):
+    def set_optimizer_state(self, optimizer_state, amp_scaler_state):
         optimizer_state = convert_to_torch_tensor(optimizer_state, device=self.device)
         self.optimizer.load_state_dict(optimizer_state)
+        amp_scaler_state = convert_to_torch_tensor(amp_scaler_state, device=self.device)
+        self.amp_scaler.load_state_dict(amp_scaler_state)

     def set_model_weights(self, model_weights):
         model_weights = convert_to_torch_tensor(model_weights, device=self.device)
         self.model.load_state_dict(model_weights)
\ No newline at end of file
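
`GradScaler` exposes `state_dict()` / `load_state_dict()` just like the optimizer, and its state is a small dict of plain Python numbers (current scale, growth/backoff factors, growth tracker), which is why it can ride through the same checkpoint path as the optimizer state. A minimal round trip, assuming a CUDA-enabled PyTorch build:

```python
from torch.cuda.amp import GradScaler

scaler = GradScaler()
snapshot = scaler.state_dict()   # e.g. {'scale': 65536.0, 'growth_factor': 2.0, ...}

fresh = GradScaler()
fresh.load_state_dict(snapshot)  # resumes the dynamic loss scale instead of re-warming it
```
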
@@ -194,43 +194,8 @@ def build_trainer(name,
             policy = Trainer.get_policy(self)
             state["custom_state_vars"] = policy.get_custom_state_vars()
             state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
+            state["amp_scaler_state"] = {k: v for k, v in policy.amp_scaler.state_dict().items()}
-            ## Ugly hack to save replay buffer because organizers taking forever to give fix for spot instances
-            # save_success = False
-            # max_size = 3_700_000_000
-            # if policy.exp_replay.nbytes < max_size:
-            #     state["replay_buffer"] = policy.exp_replay
-            #     state["buffer_saved"] = 1
-            #     policy.save_success = 1
-            #     save_success = True
-            # elif policy.exp_replay.shape[-1] == 6: # only for frame stack = 2
-            #     eq = np.all(policy.exp_replay[1:,...,:3] == policy.exp_replay[:-1,...,-3:], axis=(-3,-2,-1))
-            #     non_eq = np.where(1 - eq)
-            #     images_non_eq = policy.exp_replay[non_eq]
-            #     images_last = policy.exp_replay[-1,...,-3:]
-            #     images_first = policy.exp_replay[0,...,:3]
-            #     if policy.exp_replay[1:,...,:3].nbytes < max_size:
-            #         state["sliced_buffer"] = policy.exp_replay[1:,...,:3]
-            #         state["buffer_saved"] = 2
-            #         policy.save_success = 2
-            #         save_success = True
-            #     else:
-            #         comp = compress(policy.exp_replay[1:,...,:3].copy(), level=9)
-            #         if getsizeof(comp) < max_size:
-            #             state["compressed_buffer"] = comp
-            #             state["buffer_saved"] = 3
-            #             policy.save_success = 3
-            #             save_success = True
-            #     if save_success:
-            #         state["matched_frame_data"] = [non_eq, images_non_eq, images_last, images_first]
-            # if not save_success:
-            #     state["buffer_saved"] = -1
-            #     policy.save_success = -1
-            #     print("####################### BUFFER SAVE FAILED #########################")
-            # else:
-            #     state["retune_selector"] = policy.retune_selector
             if self.train_exec_impl:
                 state["train_exec_impl"] = (
                     self.train_exec_impl.shared_metrics.get().save())
@@ -240,28 +205,8 @@
             Trainer.__setstate__(self, state)
             policy = Trainer.get_policy(self)
             self.state = state["trainer_state"].copy()
-            policy.set_optimizer_state(state["optimizer_state"])
+            policy.set_optimizer_state(state["optimizer_state"], state["amp_scaler_state"])
             policy.set_custom_state_vars(state["custom_state_vars"])
-            ## Ugly hack to save replay buffer because organizers taking forever to give fix for spot instances
-            # buffer_saved = state.get("buffer_saved", -1)
-            # policy.save_success = buffer_saved
-            # if buffer_saved == 1:
-            #     policy.exp_replay = state["replay_buffer"]
-            # elif buffer_saved > 1:
-            #     non_eq, images_non_eq, images_last, images_first = state["matched_frame_data"]
-            #     policy.exp_replay[non_eq] = images_non_eq
-            #     policy.exp_replay[-1,...,-3:] = images_last
-            #     policy.exp_replay[0,...,:3] = images_first
-            #     if buffer_saved == 2:
-            #         policy.exp_replay[1:,...,:3] = state["sliced_buffer"]
-            #     elif buffer_saved == 3:
-            #         ts = policy.exp_replay[1:,...,:3].shape
-            #         dt = policy.exp_replay.dtype
-            #         decomp = decompress(state["compressed_buffer"])
-            #         policy.exp_replay[1:,...,:3] = np.array(np.frombuffer(decomp, dtype=dt).reshape(ts))
-            # if buffer_saved > 0:
-            #     policy.retune_selector = state["retune_selector"]
             if self.train_exec_impl:
                 self.train_exec_impl.shared_metrics.get().restore(
...
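
End to end, the trainer's `__getstate__` now stores both `optimizer_state` and the new `amp_scaler_state`, and `__setstate__` hands both to `policy.set_optimizer_state` after the objects have been rebuilt. The sketch below replays that pairing with stand-in objects and `pickle` instead of RLlib's checkpoint machinery; the key names match the diff, everything else is illustrative.

```python
import pickle
import torch
from torch.cuda.amp import GradScaler

def make_policy_parts():
    # Stand-ins for the policy's model, optimizer, and amp_scaler.
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters())
    scaler = GradScaler()
    return model, optimizer, scaler

model, optimizer, scaler = make_policy_parts()
state = {
    "optimizer_state": optimizer.state_dict(),
    "amp_scaler_state": scaler.state_dict(),   # key added by this commit
}
blob = pickle.dumps(state)

# Restore side: rebuild the objects first, then load both states together,
# mirroring policy.set_optimizer_state(optimizer_state, amp_scaler_state).
model2, optimizer2, scaler2 = make_policy_parts()
restored = pickle.loads(blob)
optimizer2.load_state_dict(restored["optimizer_state"])
scaler2.load_state_dict(restored["amp_scaler_state"])
```
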
@@ -88,6 +88,7 @@ DEFAULT_CONFIG = with_common_config({
     "updates_per_batch": 8,
     "scale_reward": 1.0,
     "return_reset": True,
+    "aux_phase_mixed_precision": False,
 })
 # __sphinx_doc_end__
 # yapf: enable
...
@@ -53,6 +53,7 @@ procgen-ppo:
         standardize_rewards: True
         scale_reward: 1.0
         return_reset: False
+        aux_phase_mixed_precision: True
         adaptive_gamma: False
         final_lr: 5.0e-5
...