Commit c0820e84 authored by Dipam Chakraborty

ppg x2 + remove timekeeping

parent 7ee5a928
@@ -80,7 +80,7 @@ class CustomTorchPolicy(TorchPolicy):
self.save_success = 0
self.target_timesteps = 8_000_000
self.buffer_time = 20 # TODO: Could try to do a median or mean time step check instead
self.max_time = 7200
self.max_time = 100000000
self.maxrewep_lenbuf = deque(maxlen=100)
self.gamma = self.config['gamma']
self.adaptive_discount_tuner = AdaptiveDiscountTuner(self.gamma, momentum=0.98, eplenmult=3)
@@ -90,6 +90,7 @@ class CustomTorchPolicy(TorchPolicy):
self.last_dones = np.zeros((nw * self.config['num_envs_per_worker'],))
self.make_distr = dist_build(action_space)
self.retunes_completed = 0
def to_tensor(self, arr):
return torch.from_numpy(arr).to(self.device)
@@ -256,6 +257,7 @@ class CustomTorchPolicy(TorchPolicy):
self.exp_replay.fill(0)
self.vtarg_replay.fill(0)
self.retunes_completed += 1
self.retune_selector.retune_done()
def tune_policy(self, obs, target_vf, target_pi):
@@ -345,6 +347,7 @@ class CustomTorchPolicy(TorchPolicy):
"best_rew_tsteps": self.best_rew_tsteps,
"best_reward": self.best_reward,
"last_dones": self.last_dones,
"retunes_completed": self.retunes_completed,
}
def set_custom_state_vars(self, custom_state_vars):
@@ -355,12 +358,13 @@ class CustomTorchPolicy(TorchPolicy):
self.batch_end_time = custom_state_vars["batch_end_time"]
self.gamma = self.adaptive_discount_tuner.gamma = custom_state_vars["gamma"]
self.maxrewep_lenbuf = custom_state_vars["maxrewep_lenbuf"]
self.lr =custom_state_vars["lr"]
self.lr = custom_state_vars["lr"]
self.ent_coef = custom_state_vars["ent_coef"]
self.rewnorm = custom_state_vars["rewnorm"]
self.best_rew_tsteps = custom_state_vars["best_rew_tsteps"]
self.best_reward = custom_state_vars["best_reward"]
self.last_dones = custom_state_vars["last_dones"]
self.retunes_completed = custom_state_vars["retunes_completed"]
@override(TorchPolicy)
def get_weights(self):
@@ -369,24 +373,24 @@ class CustomTorchPolicy(TorchPolicy):
k: v.cpu().detach().numpy()
for k, v in self.model.state_dict().items()
}
weights["optimizer_state"] = {
k: v
for k, v in self.optimizer.state_dict().items()
}
weights["aux_optimizer_state"] = {
k: v
for k, v in self.aux_optimizer.state_dict().items()
}
weights["custom_state_vars"] = self.get_custom_state_vars()
# weights["optimizer_state"] = {
# k: v
# for k, v in self.optimizer.state_dict().items()
# }
# weights["aux_optimizer_state"] = {
# k: v
# for k, v in self.aux_optimizer.state_dict().items()
# }
# weights["custom_state_vars"] = self.get_custom_state_vars()
return weights
@override(TorchPolicy)
def set_weights(self, weights):
self.set_model_weights(weights["current_weights"])
self.set_optimizer_state(weights["optimizer_state"])
self.set_aux_optimizer_state(weights["aux_optimizer_state"])
self.set_custom_state_vars(weights["custom_state_vars"])
# self.set_optimizer_state(weights["optimizer_state"])
# self.set_aux_optimizer_state(weights["aux_optimizer_state"])
# self.set_custom_state_vars(weights["custom_state_vars"])
def set_aux_optimizer_state(self, aux_optimizer_state):
aux_optimizer_state = convert_to_torch_tensor(aux_optimizer_state, device=self.device)
......
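With the hunk above, get_weights/set_weights now carry only the model weights; the optimizer, aux-optimizer and custom-state entries are commented out. A minimal sketch, using plain PyTorch and an illustrative Adam setup (not the repository's exact code), of the state_dict round trip those commented-out lines performed:

# Minimal sketch of optimizer-state round-tripping with plain PyTorch.
# The model, learning rate and variable names here are illustrative assumptions.
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

# Saving: state_dict() returns a picklable dict of step counters, moment
# buffers and param groups, analogous to weights["optimizer_state"].
optimizer_state = optimizer.state_dict()

# Restoring: a freshly built optimizer loads the saved state so training
# resumes with the same Adam moments, analogous to set_optimizer_state().
restored = torch.optim.Adam(model.parameters(), lr=5e-4)
restored.load_state_dict(optimizer_state)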
@@ -197,41 +197,41 @@ def build_trainer(name,
state["aux_optimizer_state"] = {k: v for k, v in policy.aux_optimizer.state_dict().items()}
## Ugly hack to save replay buffer because organizers taking forever to give fix for spot instances
save_success = False
max_size = 3_700_000_000
if policy.exp_replay.nbytes < max_size:
state["replay_buffer"] = policy.exp_replay
state["buffer_saved"] = 1
policy.save_success = 1
save_success = True
elif policy.exp_replay.shape[-1] == 6: # only for frame stack = 2
eq = np.all(policy.exp_replay[:,1:,...,:3] == policy.exp_replay[:,:-1,...,-3:], axis=(-3,-2,-1))
non_eq = np.where(1 - eq)
images_non_eq = policy.exp_replay[non_eq]
images_last = policy.exp_replay[:,-1,...,-3:]
images_first = policy.exp_replay[:,0,...,:3]
if policy.exp_replay[:,1:,...,:3].nbytes < max_size:
state["sliced_buffer"] = policy.exp_replay[:,1:,...,:3]
state["buffer_saved"] = 2
policy.save_success = 2
save_success = True
else:
comp = compress(policy.exp_replay[:,1:,...,:3].copy(), level=9)
if getsizeof(comp) < max_size:
state["compressed_buffer"] = comp
state["buffer_saved"] = 3
policy.save_success = 3
save_success = True
if save_success:
state["matched_frame_data"] = [non_eq, images_non_eq, images_last, images_first]
# save_success = False
# max_size = 3_700_000_000
# if policy.exp_replay.nbytes < max_size:
# state["replay_buffer"] = policy.exp_replay
# state["buffer_saved"] = 1
# policy.save_success = 1
# save_success = True
# elif policy.exp_replay.shape[-1] == 6: # only for frame stack = 2
# eq = np.all(policy.exp_replay[:,1:,...,:3] == policy.exp_replay[:,:-1,...,-3:], axis=(-3,-2,-1))
# non_eq = np.where(1 - eq)
# images_non_eq = policy.exp_replay[non_eq]
# images_last = policy.exp_replay[:,-1,...,-3:]
# images_first = policy.exp_replay[:,0,...,:3]
# if policy.exp_replay[:,1:,...,:3].nbytes < max_size:
# state["sliced_buffer"] = policy.exp_replay[:,1:,...,:3]
# state["buffer_saved"] = 2
# policy.save_success = 2
# save_success = True
# else:
# comp = compress(policy.exp_replay[:,1:,...,:3].copy(), level=9)
# if getsizeof(comp) < max_size:
# state["compressed_buffer"] = comp
# state["buffer_saved"] = 3
# policy.save_success = 3
# save_success = True
# if save_success:
# state["matched_frame_data"] = [non_eq, images_non_eq, images_last, images_first]
if not save_success:
state["buffer_saved"] = -1
policy.save_success = -1
print("####################### BUFFER SAVE FAILED #########################")
else:
state["vtarg_replay"] = policy.vtarg_replay
state["retune_selector"] = policy.retune_selector
# if not save_success:
# state["buffer_saved"] = -1
# policy.save_success = -1
# print("####################### BUFFER SAVE FAILED #########################")
# else:
# state["vtarg_replay"] = policy.vtarg_replay
# state["retune_selector"] = policy.retune_selector
if self.train_exec_impl:
@@ -248,25 +248,25 @@ def build_trainer(name,
policy.set_custom_state_vars(state["custom_state_vars"])
## Ugly hack to save replay buffer because organizers taking forever to give fix for spot instances
buffer_saved = state.get("buffer_saved", -1)
policy.save_success = buffer_saved
if buffer_saved == 1:
policy.exp_replay = state["replay_buffer"]
elif buffer_saved > 1:
non_eq, images_non_eq, images_last, images_first = state["matched_frame_data"]
policy.exp_replay[non_eq] = images_non_eq
policy.exp_replay[:,-1,...,-3:] = images_last
policy.exp_replay[:,0,...,:3] = images_first
if buffer_saved == 2:
policy.exp_replay[:,1:,...,:3] = state["sliced_buffer"]
elif buffer_saved == 3:
ts = policy.exp_replay[:,1:,...,:3].shape
dt = policy.exp_replay.dtype
decomp = decompress(state["compressed_buffer"])
policy.exp_replay[:,1:,...,:3] = np.array(np.frombuffer(decomp, dtype=dt).reshape(ts))
if buffer_saved > 0:
policy.vtarg_replay = state["vtarg_replay"]
policy.retune_selector = state["retune_selector"]
# buffer_saved = state.get("buffer_saved", -1)
# policy.save_success = buffer_saved
# if buffer_saved == 1:
# policy.exp_replay = state["replay_buffer"]
# elif buffer_saved > 1:
# non_eq, images_non_eq, images_last, images_first = state["matched_frame_data"]
# policy.exp_replay[non_eq] = images_non_eq
# policy.exp_replay[:,-1,...,-3:] = images_last
# policy.exp_replay[:,0,...,:3] = images_first
# if buffer_saved == 2:
# policy.exp_replay[:,1:,...,:3] = state["sliced_buffer"]
# elif buffer_saved == 3:
# ts = policy.exp_replay[:,1:,...,:3].shape
# dt = policy.exp_replay.dtype
# decomp = decompress(state["compressed_buffer"])
# policy.exp_replay[:,1:,...,:3] = np.array(np.frombuffer(decomp, dtype=dt).reshape(ts))
# if buffer_saved > 0:
# policy.vtarg_replay = state["vtarg_replay"]
# policy.retune_selector = state["retune_selector"]
if self.train_exec_impl:
self.train_exec_impl.shared_metrics.get().restore(
......
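The save/restore hack disabled above leans on the redundancy of a frame_stack = 2 buffer: each observation's older frame is the previous observation's newer frame, so only a 3-channel slice plus the first/last frames and any mismatching positions needs to be stored. A standalone sketch of that overlap property, with a toy buffer shape assumed for illustration:

# Standalone sketch of the frame-stack overlap the disabled hack relies on
# (the shapes and toy buffer below are assumptions for illustration).
import numpy as np

N, T, H, W = 2, 5, 8, 8
frames = np.random.randint(0, 255, (N, T + 1, H, W, 3), dtype=np.uint8)

# Build a frame_stack=2 buffer: channels 0:3 hold the older frame, 3:6 the newer one.
exp_replay = np.concatenate([frames[:, :-1], frames[:, 1:]], axis=-1)  # (N, T, H, W, 6)

# Overlap check mirroring the commented-out save path: the older half of obs t
# equals the newer half of obs t-1 wherever no episode boundary intervenes.
eq = np.all(exp_replay[:, 1:, ..., :3] == exp_replay[:, :-1, ..., -3:], axis=(-3, -2, -1))
assert eq.all()  # no boundaries in this toy buffer

# So storing exp_replay[:, 1:, ..., :3] plus the first/last frames and the
# mismatching positions is enough to rebuild the buffer at roughly half the bytes.
sliced = exp_replay[:, 1:, ..., :3]
print(sliced.nbytes / exp_replay.nbytes)  # ~0.4 for this toy shape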
@@ -2,8 +2,8 @@ import logging
from ray.rllib.agents import with_common_config
from .custom_torch_ppg import CustomTorchPolicy
from ray.rllib.agents.trainer_template import build_trainer
# from .custom_trainer_template import build_trainer
# from ray.rllib.agents.trainer_template import build_trainer
from .custom_trainer_template import build_trainer
logger = logging.getLogger(__name__)
......
@@ -47,9 +47,9 @@ procgen-ppo:
# Custom switches
skips: 0
n_pi: 10
n_pi: 1
num_retunes: 100
retune_epochs: 6
retune_epochs: 3
standardize_rewards: True
aux_mbsize: 4
augment_buffer: False
@@ -63,8 +63,7 @@ procgen-ppo:
# Memory management, if batch size overflow, batch splitting is done to handle it
max_minibatch_size: 2048
updates_per_batch: 8
updates_per_batch: 8
normalize_actions: False
clip_rewards: null
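The memory-management comment in the hunk above refers to splitting an oversized train batch into smaller passes. A hedged sketch of that gradient-accumulation idea; the helper below is illustrative, not the repository's actual update loop:

# Hedged sketch of batch splitting with gradient accumulation.
# train_on_batch, compute_loss and the default size are illustrative assumptions.
import math
import torch

def train_on_batch(model, optimizer, compute_loss, batch, max_minibatch_size=2048):
    """Split `batch` into chunks no larger than max_minibatch_size and
    accumulate gradients so the update matches a single full-batch pass."""
    n = batch.shape[0]
    splits = max(1, math.ceil(n / max_minibatch_size))
    optimizer.zero_grad()
    for chunk in torch.chunk(batch, splits):
        loss = compute_loss(model, chunk)        # mean loss over the chunk
        (loss * chunk.shape[0] / n).backward()   # weight by chunk size
    optimizer.step()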
@@ -80,15 +79,15 @@ procgen-ppo:
model:
custom_model: impala_torch_ppg
custom_options:
depths: [16, 32, 32]
nlatents: 256
# depths: [32, 64, 64]
# nlatents: 512
# depths: [16, 32, 32]
# nlatents: 256
depths: [32, 64, 64]
nlatents: 512
init_normed: True
use_layernorm: False
num_workers: 7
num_envs_per_worker: 16
num_envs_per_worker: 9
rollout_fragment_length: 256
......
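The custom_options change above doubles the encoder (depths [16, 32, 32] to [32, 64, 64], nlatents 256 to 512), presumably the "x2" in the commit message. A rough sketch of an IMPALA-style encoder at those sizes, assuming the usual conv/pool/residual layout and 64x64 Procgen frames with frame_stack = 2; this is illustrative, not the repository's impala_torch_ppg implementation:

# Illustrative IMPALA-style encoder at the doubled sizes (not the repo's model code).
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv0 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        out = self.conv0(F.relu(x))
        out = self.conv1(F.relu(out))
        return x + out

class ImpalaEncoder(nn.Module):
    def __init__(self, in_channels=6, depths=(32, 64, 64), nlatents=512):
        super().__init__()
        blocks, c = [], in_channels
        for d in depths:
            blocks += [
                nn.Conv2d(c, d, 3, padding=1),
                nn.MaxPool2d(3, stride=2, padding=1),  # halves spatial size
                ResidualBlock(d),
                ResidualBlock(d),
            ]
            c = d
        self.conv = nn.Sequential(*blocks)
        # 64x64 input halved three times -> 8x8 feature map
        self.fc = nn.Linear(depths[-1] * 8 * 8, nlatents)

    def forward(self, obs):
        x = self.conv(obs)
        x = F.relu(torch.flatten(x, start_dim=1))
        return F.relu(self.fc(x))

# e.g. a batch of 4 frame-stacked 64x64 observations (frame_stack = 2 -> 6 channels)
z = ImpalaEncoder()(torch.zeros(4, 6, 64, 64))
assert z.shape == (4, 512)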
@@ -6,8 +6,8 @@ set -e
#########################################
# export EXPERIMENT_DEFAULT="experiments/impala-baseline.yaml"
export EXPERIMENT_DEFAULT="experiments/custom-torch-ppo.yaml"
# export EXPERIMENT_DEFAULT="experiments/custom-ppg.yaml"
# export EXPERIMENT_DEFAULT="experiments/custom-torch-ppo.yaml"
export EXPERIMENT_DEFAULT="experiments/custom-ppg.yaml"
export EXPERIMENT=${EXPERIMENT:-$EXPERIMENT_DEFAULT}
if [[ -z $AICROWD_IS_GRADING ]]; then
......