Commit cca87e73 authored by Chakraborty

merge ppg

parents cbeaecd6 606189ba
......@@ -22,6 +22,8 @@ class CustomTorchPolicy(TorchPolicy):
def __init__(self, observation_space, action_space, config):
self.config = config
self.action_space = action_space
self.observation_space = observation_space
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dist_class, logit_dim = ModelCatalog.get_action_dist(
......@@ -46,6 +48,10 @@ class CustomTorchPolicy(TorchPolicy):
)
self.framework = "torch"
def init_training(self):
""" Init once only for the policy - Surely there should be a bette way to do this """
aux_params = set(self.model.aux_vf.parameters())
value_params = set(self.model.value_fc.parameters())
network_params = set(self.model.parameters())
......@@ -80,7 +86,7 @@ class CustomTorchPolicy(TorchPolicy):
print("WARNING: MEMORY LIMITED BATCHING NOT SET PROPERLY")
print("#################################################")
replay_shape = (n_pi, nsteps, nenvs)
self.retune_selector = RetuneSelector(nenvs, observation_space, action_space, replay_shape,
self.retune_selector = RetuneSelector(nenvs, self.observation_space, self.action_space, replay_shape,
skips = self.config['skips'],
n_pi = n_pi,
num_retunes = self.config['num_retunes'],
......@@ -93,11 +99,11 @@ class CustomTorchPolicy(TorchPolicy):
self.gamma = self.config['gamma']
self.adaptive_discount_tuner = AdaptiveDiscountTuner(self.gamma, momentum=0.98, eplenmult=3)
self.lr = config['lr']
self.ent_coef = config['entropy_coeff']
self.lr = self.config['lr']
self.ent_coef = self.config['entropy_coeff']
self.last_dones = np.zeros((nw * self.config['num_envs_per_worker'],))
self.make_distr = dist_build(action_space)
self.make_distr = dist_build(self.action_space)
self.retunes_completed = 0
self.amp_scaler = GradScaler()
......
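The hunks above split policy setup in two: __init__ now only records the spaces, the config and the device, while the optimizers, the RetuneSelector and the GradScaler move into init_training(), which the trainer calls once after its own initialization. A minimal sketch of that deferred-initialization pattern, not taken from the commit (MinimalPolicy and the placeholder model are illustrative):

import torch

class MinimalPolicy:
    def __init__(self, observation_space, action_space, config):
        # Construction only records what later setup needs; nothing heavy happens here.
        self.observation_space = observation_space
        self.action_space = action_space
        self.config = config
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model = torch.nn.Linear(4, 2).to(self.device)  # placeholder network

    def init_training(self):
        # Called exactly once by the trainer after it finishes building everything,
        # so objects that depend on the final merged config are created here.
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])
        self.amp_scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

policy = MinimalPolicy(None, None, {"lr": 5e-4})
policy.init_training()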
......@@ -140,6 +140,9 @@ def build_trainer(name,
if after_init:
after_init(self)
policy = Trainer.get_policy(self)
policy.init_training()
@override(Trainer)
def _train(self):
if self.train_exec_impl:
......@@ -192,11 +195,14 @@ def build_trainer(name,
state = Trainer.__getstate__(self)
state["trainer_state"] = self.state.copy()
policy = Trainer.get_policy(self)
try:
state["custom_state_vars"] = policy.get_custom_state_vars()
state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
state["aux_optimizer_state"] = {k: v for k, v in policy.aux_optimizer.state_dict().items()}
state["value_optimizer_state"] = {k: v for k, v in policy.value_optimizer.state_dict().items()}
state["amp_scaler_state"] = {k: v for k, v in policy.amp_scaler.state_dict().items()}
except Exception:
print("################# WARNING: SAVING STATE VARS AND OPTIMIZER FAILED ################")
if self.train_exec_impl:
state["train_exec_impl"] = (
......@@ -207,9 +213,12 @@ def build_trainer(name,
Trainer.__setstate__(self, state)
policy = Trainer.get_policy(self)
self.state = state["trainer_state"].copy()
try:
policy.set_optimizer_state(state["optimizer_state"], state["aux_optimizer_state"],
state["value_optimizer_state"], state["amp_scaler_state"])
policy.set_custom_state_vars(state["custom_state_vars"])
except Exception:
print("################# WARNING: LOADING STATE VARS AND OPTIMIZER FAILED ################")
if self.train_exec_impl:
self.train_exec_impl.shared_metrics.get().restore(
......
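Both trainer-template hunks fold the optimizer and GradScaler state dicts into the checkpoint and wrap save and restore in try/except so a failure only prints a warning instead of killing the run. A self-contained sketch of that round trip under hypothetical helper names (save_training_state and load_training_state are not from the commit):

import torch
from torch.cuda.amp import GradScaler

def save_training_state(optimizer, amp_scaler):
    # Mirrors __getstate__ above: collect the state dicts, warn rather than crash on failure.
    state = {}
    try:
        state["optimizer_state"] = optimizer.state_dict()
        state["amp_scaler_state"] = amp_scaler.state_dict()
    except Exception:
        print("WARNING: saving optimizer/scaler state failed")
    return state

def load_training_state(state, optimizer, amp_scaler):
    # Mirrors __setstate__ above: restore what was saved, warn on failure.
    try:
        optimizer.load_state_dict(state["optimizer_state"])
        amp_scaler.load_state_dict(state["amp_scaler_state"])
    except Exception:
        print("WARNING: loading optimizer/scaler state failed")

model = torch.nn.Linear(8, 4)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=False)  # disabled so the example also runs on CPU-only machines
ckpt = save_training_state(opt, scaler)
load_training_state(ckpt, opt, scaler)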
......@@ -9,6 +9,7 @@ from .utils import *
import time
torch, nn = try_import_torch()
from torch.cuda.amp import autocast, GradScaler
class CustomTorchPolicy(TorchPolicy):
"""Example of a random policy
......@@ -20,6 +21,8 @@ class CustomTorchPolicy(TorchPolicy):
def __init__(self, observation_space, action_space, config):
self.config = config
self.action_space = action_space
self.observation_space = observation_space
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dist_class, logit_dim = ModelCatalog.get_action_dist(
......@@ -42,7 +45,11 @@ class CustomTorchPolicy(TorchPolicy):
loss=None,
action_distribution_class=dist_class,
)
self.framework = "torch"
def init_training(self):
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
self.max_reward = self.config['env_config']['return_max']
self.rewnorm = RewardNormalizer(cliprew=self.max_reward) ## TODO: Might need to go to custom state
......@@ -63,24 +70,25 @@ class CustomTorchPolicy(TorchPolicy):
print("#################################################")
print("WARNING: MEMORY LIMITED BATCHING NOT SET PROPERLY")
print("#################################################")
self.retune_selector = RetuneSelector(self.nbatch, observation_space, action_space,
self.retune_selector = RetuneSelector(self.nbatch, self.observation_space, self.action_space,
skips = self.config['retune_skips'],
replay_size = self.config['retune_replay_size'],
num_retunes = self.config['num_retunes'])
self.exp_replay = np.zeros((self.retune_selector.replay_size, *observation_space.shape), dtype=np.uint8)
self.exp_replay = np.empty((self.retune_selector.replay_size, *self.observation_space.shape), dtype=np.uint8)
self.target_timesteps = 8_000_000
self.buffer_time = 20 # TODO: Could try to do a median or mean time step check instead
self.max_time = 10000000000000 # ignore timekeeping because spot instances are messing it up
self.max_time = self.config['max_time']
self.maxrewep_lenbuf = deque(maxlen=100)
self.gamma = self.config['gamma']
self.adaptive_discount_tuner = AdaptiveDiscountTuner(self.gamma, momentum=0.98, eplenmult=3)
self.lr = config['lr']
self.ent_coef = config['entropy_coeff']
self.lr = self.config['lr']
self.ent_coef = self.config['entropy_coeff']
self.last_dones = np.zeros((nw * self.config['num_envs_per_worker'],))
self.save_success = 0
self.retunes_completed = 0
self.amp_scaler = GradScaler()
def to_tensor(self, arr):
return torch.from_numpy(arr).to(self.device)
......@@ -268,7 +276,6 @@ class CustomTorchPolicy(TorchPolicy):
self.to_tensor(replay_pi[mbinds])]
self.tune_policy(apply_grad, *slices, 0.5)
self.exp_replay.fill(0)
self.retunes_completed += 1
self.retune_selector.retune_done()
......@@ -282,6 +289,23 @@ class CustomTorchPolicy(TorchPolicy):
with torch.no_grad():
tpi_log_softmax = nn.functional.log_softmax(target_pi, dim=1)
tpi_softmax = torch.exp(tpi_log_softmax)
if not self.config['aux_phase_mixed_precision']:
loss = self._retune_calc_loss(obs_aug, target_vf, tpi_softmax, tpi_log_softmax, retune_vf_loss_coeff)
loss.backward()
if apply_grad:
self.optimizer.step()
self.optimizer.zero_grad()
else:
with autocast():
loss = self._retune_calc_loss(obs_aug, target_vf, tpi_softmax, tpi_log_softmax, retune_vf_loss_coeff)
self.amp_scaler.scale(loss).backward()
if apply_grad:
self.amp_scaler.step(self.optimizer)
self.amp_scaler.update()
self.optimizer.zero_grad()
def _retune_calc_loss(self, obs_aug, target_vf, tpi_softmax, tpi_log_softmax, retune_vf_loss_coeff):
vpred, pi_logits = self.model.vf_pi(obs_aug, ret_numpy=False, no_grad=False, to_torch=False)
pi_log_softmax = nn.functional.log_softmax(pi_logits, dim=1)
pi_loss = torch.mean(torch.sum(tpi_softmax * (tpi_log_softmax - pi_log_softmax), dim=1)) # manual KL: kl_div in torch 1.3.1 has numerical issues
......@@ -289,11 +313,7 @@ class CustomTorchPolicy(TorchPolicy):
loss = retune_vf_loss_coeff * vf_loss + pi_loss
loss = loss / self.accumulate_train_batches
loss.backward()
if apply_grad:
self.optimizer.step()
self.optimizer.zero_grad()
return loss
def best_reward_model_select(self, samples):
self.timesteps_total += self.nbatch
......@@ -384,24 +404,20 @@ class CustomTorchPolicy(TorchPolicy):
k: v.cpu().detach().numpy()
for k, v in self.model.state_dict().items()
}
# weights["optimizer_state"] = {
# k: v
# for k, v in self.optimizer.state_dict().items()
# }
# weights["custom_state_vars"] = self.get_custom_state_vars()
return weights
@override(TorchPolicy)
def set_weights(self, weights):
self.set_model_weights(weights["current_weights"])
# self.set_optimizer_state(weights["optimizer_state"])
# self.set_custom_state_vars(weights["custom_state_vars"])
def set_optimizer_state(self, optimizer_state):
def set_optimizer_state(self, optimizer_state, amp_scaler_state):
optimizer_state = convert_to_torch_tensor(optimizer_state, device=self.device)
self.optimizer.load_state_dict(optimizer_state)
amp_scaler_state = convert_to_torch_tensor(amp_scaler_state, device=self.device)
self.amp_scaler.load_state_dict(amp_scaler_state)
def set_model_weights(self, model_weights):
model_weights = convert_to_torch_tensor(model_weights, device=self.device)
self.model.load_state_dict(model_weights)
\ No newline at end of file
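The aux/retune update above now has two paths: a plain FP32 path, and, when aux_phase_mixed_precision is set, an autocast path in which the GradScaler scales the loss, gradients accumulate across minibatches, and the optimizer only steps when apply_grad is true. A standalone sketch of that pattern under the same assumptions (compute_loss, the linear model and the random batch are placeholders, not the commit's code):

import torch
from torch.cuda.amp import autocast, GradScaler

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = torch.nn.Linear(16, 4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=torch.cuda.is_available())

def compute_loss(batch, accumulate):
    # Placeholder loss; the real code combines a KL term and a value loss,
    # then divides by the number of accumulated minibatches.
    return model(batch).pow(2).mean() / accumulate

def aux_step(batch, apply_grad, use_amp, accumulate=2):
    if not use_amp:
        loss = compute_loss(batch, accumulate)
        loss.backward()
        if apply_grad:
            optimizer.step()
            optimizer.zero_grad()
    else:
        with autocast():
            loss = compute_loss(batch, accumulate)
        # Scale before backward so small FP16 gradients do not underflow.
        scaler.scale(loss).backward()
        if apply_grad:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

aux_step(torch.randn(32, 16, device=device), apply_grad=True, use_amp=torch.cuda.is_available())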
......@@ -140,6 +140,9 @@ def build_trainer(name,
if after_init:
after_init(self)
policy = Trainer.get_policy(self)
policy.init_training()
@override(Trainer)
def _train(self):
if self.train_exec_impl:
......@@ -192,44 +195,12 @@ def build_trainer(name,
state = Trainer.__getstate__(self)
state["trainer_state"] = self.state.copy()
policy = Trainer.get_policy(self)
try:
state["custom_state_vars"] = policy.get_custom_state_vars()
state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
## Ugly hack to save the replay buffer because the organizers are taking forever to give a fix for spot instances
# save_success = False
# max_size = 3_700_000_000
# if policy.exp_replay.nbytes < max_size:
# state["replay_buffer"] = policy.exp_replay
# state["buffer_saved"] = 1
# policy.save_success = 1
# save_success = True
# elif policy.exp_replay.shape[-1] == 6: # only for frame stack = 2
# eq = np.all(policy.exp_replay[1:,...,:3] == policy.exp_replay[:-1,...,-3:], axis=(-3,-2,-1))
# non_eq = np.where(1 - eq)
# images_non_eq = policy.exp_replay[non_eq]
# images_last = policy.exp_replay[-1,...,-3:]
# images_first = policy.exp_replay[0,...,:3]
# if policy.exp_replay[1:,...,:3].nbytes < max_size:
# state["sliced_buffer"] = policy.exp_replay[1:,...,:3]
# state["buffer_saved"] = 2
# policy.save_success = 2
# save_success = True
# else:
# comp = compress(policy.exp_replay[1:,...,:3].copy(), level=9)
# if getsizeof(comp) < max_size:
# state["compressed_buffer"] = comp
# state["buffer_saved"] = 3
# policy.save_success = 3
# save_success = True
# if save_success:
# state["matched_frame_data"] = [non_eq, images_non_eq, images_last, images_first]
# if not save_success:
# state["buffer_saved"] = -1
# policy.save_success = -1
# print("####################### BUFFER SAVE FAILED #########################")
# else:
# state["retune_selector"] = policy.retune_selector
state["amp_scaler_state"] = {k: v for k, v in policy.amp_scaler.state_dict().items()}
except Exception:
print("################# WARNING: SAVING STATE VARS AND OPTIMIZER FAILED ################")
if self.train_exec_impl:
state["train_exec_impl"] = (
......@@ -240,28 +211,11 @@ def build_trainer(name,
Trainer.__setstate__(self, state)
policy = Trainer.get_policy(self)
self.state = state["trainer_state"].copy()
policy.set_optimizer_state(state["optimizer_state"])
try:
policy.set_optimizer_state(state["optimizer_state"], state["amp_scaler_state"])
policy.set_custom_state_vars(state["custom_state_vars"])
## Ugly hack to save the replay buffer because the organizers are taking forever to give a fix for spot instances
# buffer_saved = state.get("buffer_saved", -1)
# policy.save_success = buffer_saved
# if buffer_saved == 1:
# policy.exp_replay = state["replay_buffer"]
# elif buffer_saved > 1:
# non_eq, images_non_eq, images_last, images_first = state["matched_frame_data"]
# policy.exp_replay[non_eq] = images_non_eq
# policy.exp_replay[-1,...,-3:] = images_last
# policy.exp_replay[0,...,:3] = images_first
# if buffer_saved == 2:
# policy.exp_replay[1:,...,:3] = state["sliced_buffer"]
# elif buffer_saved == 3:
# ts = policy.exp_replay[1:,...,:3].shape
# dt = policy.exp_replay.dtype
# decomp = decompress(state["compressed_buffer"])
# policy.exp_replay[1:,...,:3] = np.array(np.frombuffer(decomp, dtype=dt).reshape(ts))
# if buffer_saved > 0:
# policy.retune_selector = state["retune_selector"]
except Exception:
print("################# WARNING: LOADING STATE VARS AND OPTIMIZER FAILED ################")
if self.train_exec_impl:
self.train_exec_impl.shared_metrics.get().restore(
......
......@@ -88,6 +88,8 @@ DEFAULT_CONFIG = with_common_config({
"updates_per_batch": 8,
"scale_reward": 1.0,
"return_reset": True,
"aux_phase_mixed_precision": False,
"max_time": 100000000,
})
# __sphinx_doc_end__
# yapf: enable
......
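The two new keys map onto the policy changes above: aux_phase_mixed_precision selects the autocast/GradScaler branch of the aux-phase update, and max_time replaces the previously hard-coded timekeeping cutoff (the experiment configs set it to 7200 s). A tiny illustration with a plain dict standing in for the merged trainer config:

# Illustration only; config is a hand-built dict and the defaults match DEFAULT_CONFIG above.
config = {"aux_phase_mixed_precision": False, "max_time": 100000000}

use_amp = config["aux_phase_mixed_precision"]  # chooses the autocast/GradScaler path
max_time = config["max_time"]                  # wall-clock budget read in init_training()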
......@@ -8,7 +8,7 @@ procgen-ppo:
time_total_s: 7200
# === Settings for Checkpoints ===
checkpoint_freq: 1
checkpoint_freq: 100
checkpoint_at_end: True
keep_checkpoints_num: 5
......@@ -47,7 +47,7 @@ procgen-ppo:
# Custom switches
skips: 2
n_pi: 16
num_retunes: 26
num_retunes: 14
retune_epochs: 6
standardize_rewards: True
aux_mbsize: 4
......@@ -60,7 +60,7 @@ procgen-ppo:
value_lr: 1.0e-3
same_lr_everywhere: False
aux_phase_mixed_precision: True
single_optimizer: False
single_optimizer: True
max_time: 7200
adaptive_gamma: False
......@@ -70,7 +70,7 @@ procgen-ppo:
entropy_schedule: False
# Memory management, if batch size overflow, batch splitting is done to handle it
max_minibatch_size: 1500
max_minibatch_size: 1000
updates_per_batch: 8
normalize_actions: False
......@@ -87,8 +87,6 @@ procgen-ppo:
model:
custom_model: impala_torch_ppg
custom_model_config:
# depths: [16, 32, 32]
# nlatents: 256
depths: [32, 64, 64]
nlatents: 512
init_normed: True
......@@ -96,7 +94,7 @@ procgen-ppo:
diff_framestack: True
num_workers: 7
num_envs_per_worker: 9
num_envs_per_worker: 16
rollout_fragment_length: 256
......
......@@ -46,13 +46,15 @@ procgen-ppo:
no_done_at_end: False
# Custom switches
retune_skips: 450000
retune_replay_size: 200000
num_retunes: 11
retune_skips: 100000
retune_replay_size: 400000
num_retunes: 14
retune_epochs: 3
standardize_rewards: True
scale_reward: 1.0
return_reset: False
aux_phase_mixed_precision: True
max_time: 7200
adaptive_gamma: False
final_lr: 5.0e-5
......@@ -61,7 +63,7 @@ procgen-ppo:
entropy_schedule: False
# Memory management, if batch size overflow, batch splitting is done to handle it
max_minibatch_size: 2048
max_minibatch_size: 1000
updates_per_batch: 8
normalize_actions: False
......@@ -77,12 +79,12 @@ procgen-ppo:
# === Settings for Model ===
model:
custom_model: impala_torch_custom
custom_options:
custom_model_config:
depths: [32, 64, 64]
nlatents: 512
use_layernorm: True
diff_framestack: True
d2rl: True
d2rl: False
num_workers: 7
num_envs_per_worker: 16
......
......@@ -87,11 +87,11 @@ class ImpalaCNN(TorchModelV2, nn.Module):
nn.Module.__init__(self)
self.device = device
depths = model_config['custom_options'].get('depths') or [16, 32, 32]
nlatents = model_config['custom_options'].get('nlatents') or 256
d2rl = model_config['custom_options'].get('d2rl') or False
self.use_layernorm = model_config['custom_options'].get('use_layernorm') or False
self.diff_framestack = model_config['custom_options'].get('diff_framestack') or False
depths = model_config['custom_model_config'].get('depths')
nlatents = model_config['custom_model_config'].get('nlatents')
d2rl = model_config['custom_model_config'].get('d2rl')
self.use_layernorm = model_config['custom_model_config'].get('use_layernorm')
self.diff_framestack = model_config['custom_model_config'].get('diff_framestack')
h, w, c = obs_space.shape
......
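The model now reads its options from model_config['custom_model_config'] and drops the old or-based fallbacks, so a missing key silently comes back as None. If defaults are still wanted, an equivalent read with explicit fallbacks could look like the sketch below (the values mirror the removed defaults; the dict is hand-built for illustration):

# Illustration only; model_config is a plain dict, not RLlib's full model config.
model_config = {"custom_model_config": {"depths": [32, 64, 64], "nlatents": 512}}

opts = model_config.get("custom_model_config", {})
depths = opts.get("depths", [16, 32, 32])            # old fallback
nlatents = opts.get("nlatents", 256)                 # old fallback
d2rl = opts.get("d2rl", False)
use_layernorm = opts.get("use_layernorm", False)
diff_framestack = opts.get("diff_framestack", False)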