Commit ac409306 authored by Dipam Chakraborty

buffer save hack final

parent 4f1dd1ac
@@ -75,7 +75,7 @@ class CustomTorchPolicy(TorchPolicy):
                                              num_retunes = self.config['num_retunes'])
         replay_shape = (n_pi, nsteps, nenvs)
-        self.exp_replay = np.empty((*replay_shape, *observation_space.shape), dtype=np.uint8)
+        self.exp_replay = np.zeros((*replay_shape, *observation_space.shape), dtype=np.uint8)
         self.vtarg_replay = np.empty(replay_shape, dtype=np.float32)
         self.save_success = 0
         self.target_timesteps = 8_000_000
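For scale, a rough sketch of the buffer footprint this hunk allocates, which is what motivates the size checks in the trainer template further down; the concrete numbers are assumptions pulled from the YAML changes later in this commit, not values read from the code. (The switch from np.empty to np.zeros presumably also makes the buffer's never-written regions deterministic, which matters for the partial restore in __setstate__ below.)

import numpy as np

# Assumed sizes: n_pi from the config below, nsteps from rollout_fragment_length,
# nenvs from num_workers * num_envs_per_worker, and 64x64 RGB procgen frames with
# frame stack 2 (6 channels). Purely illustrative.
n_pi, nsteps, nenvs = 10, 256, 7 * 16
obs_shape = (64, 64, 6)
replay_shape = (n_pi, nsteps, nenvs)

nbytes = int(np.prod((*replay_shape, *obs_shape))) * np.dtype(np.uint8).itemsize
print(nbytes / 1e9)  # ~7.0 GB, well above the ~3.7 GB cap used in __getstate__ below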
@@ -335,7 +335,7 @@ class CustomTorchPolicy(TorchPolicy):
             "best_weights": self.best_weights,
             "reward_deque": self.reward_deque,
             "batch_end_time": self.batch_end_time,
-            "retune_selector": self.retune_selector,
+            # "retune_selector": self.retune_selector,
             "gamma": self.gamma,
             "maxrewep_lenbuf": self.maxrewep_lenbuf,
             "lr": self.lr,
@@ -352,7 +352,7 @@ class CustomTorchPolicy(TorchPolicy):
         self.best_weights = custom_state_vars["best_weights"]
         self.reward_deque = custom_state_vars["reward_deque"]
         self.batch_end_time = custom_state_vars["batch_end_time"]
-        self.retune_selector = custom_state_vars["retune_selector"]
+        # self.retune_selector = custom_state_vars["retune_selector"]
         self.gamma = self.adaptive_discount_tuner.gamma = custom_state_vars["gamma"]
         self.maxrewep_lenbuf = custom_state_vars["maxrewep_lenbuf"]
         self.lr =custom_state_vars["lr"]
@@ -6,10 +6,11 @@ from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
 from ray.rllib.optimizers import SyncSamplesOptimizer
 from ray.rllib.utils import add_mixins
 from ray.rllib.utils.annotations import override, DeveloperAPI
-from zlib import compress, decompress
+import numpy as np
+from zlib import compress, decompress
 from sys import getsizeof
 logger = logging.getLogger(__name__)
@@ -191,23 +192,48 @@ def build_trainer(name,
             state = Trainer.__getstate__(self)
             state["trainer_state"] = self.state.copy()
             policy = Trainer.get_policy(self)
-            state["vtarg_replay"] = policy.vtarg_replay
             state["custom_state_vars"] = policy.get_custom_state_vars()
             state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
             state["aux_optimizer_state"] = {k: v for k, v in policy.aux_optimizer.state_dict().items()}
-            if getsizeof(policy.exp_replay) < 3_500_000_000:
+            ## Ugly hack to save replay buffer because organizers taking forever to give fix for spot instances
+            save_success = False
+            max_size = 3_700_000_000
+            if policy.exp_replay.nbytes < max_size:
                 state["replay_buffer"] = policy.exp_replay
-                policy.save_success = 1
-            else:
-                replay_compressed = compress(policy.exp_replay, level=9)
-                if getsizeof(replay_compressed) < 3_500_000_000:
-                    state["replay_buffer"] = replay_compressed
-                    state["buffer_info"] = [policy.exp_replay.shape, policy.exp_replay.dtype]
-                    policy.save_success = 2
-                    # print("Compression Success", getsizeof(replay_compressed))
+                state["buffer_saved"] = 1
+                policy.save_success = 1
+                save_success = True
+            elif policy.exp_replay.shape[-1] == 6: # only for frame stack = 2
+                eq = np.all(policy.exp_replay[:,1:,...,:3] == policy.exp_replay[:,:-1,...,-3:], axis=(-3,-2,-1))
+                non_eq = np.where(1 - eq)
+                images_non_eq = policy.exp_replay[non_eq]
+                images_last = policy.exp_replay[:,-1,...,-3:]
+                images_first = policy.exp_replay[:,0,...,:3]
+                if policy.exp_replay[:,1:,...,:3].nbytes < max_size:
+                    state["sliced_buffer"] = policy.exp_replay[:,1:,...,:3]
+                    state["buffer_saved"] = 2
+                    policy.save_success = 2
+                    save_success = True
                 else:
-                    policy.save_success = -1
-                    # print("Compression Failed", getsizeof(replay_compressed))
+                    comp = compress(policy.exp_replay[:,1:,...,:3].copy(), level=9)
+                    if getsizeof(comp) < max_size:
+                        state["compressed_buffer"] = comp
+                        state["buffer_saved"] = 3
+                        policy.save_success = 3
+                        save_success = True
+                if save_success:
+                    state["matched_frame_data"] = [non_eq, images_non_eq, images_last, images_first]
+            if not save_success:
+                state["buffer_saved"] = -1
+                policy.save_success = -1
+                print("####################### BUFFER SAVE FAILED #########################")
+            else:
+                state["vtarg_replay"] = policy.vtarg_replay
+                state["retune_selector"] = policy.retune_selector
             if self.train_exec_impl:
                 state["train_exec_impl"] = (
                     self.train_exec_impl.shared_metrics.get().save())
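The new elif branch above exploits the redundancy of frame-stacked observations: with a stack of 2, the leading 3 channels of step t+1 normally duplicate the trailing 3 channels of step t, so close to half of the buffer is redundant except at episode boundaries. A minimal, self-contained sketch of that property with toy shapes; it illustrates the idea only and is not the commit's exact bookkeeping.

import numpy as np

# Toy rollout: 8 observations of 4x4 RGB frames with frame stack 2 (6 channels).
rng = np.random.default_rng(0)
frames = rng.integers(0, 256, size=(9, 4, 4, 3), dtype=np.uint8)  # raw frames 0..8
obs = np.concatenate([frames[:-1], frames[1:]], axis=-1)          # obs[t] = [frame_t, frame_{t+1}]

# The overlap the save path tests for: within an episode, the leading channels
# of obs[t+1] equal the trailing channels of obs[t].
eq = np.all(obs[1:, ..., :3] == obs[:-1, ..., -3:], axis=(-3, -2, -1))
assert eq.all()

# So storing obs[..., :3], the final frame, and any positions where the overlap
# fails (episode boundaries) is enough to rebuild the full stack.
rebuilt = np.zeros_like(obs)
rebuilt[..., :3] = obs[..., :3]
rebuilt[:-1, ..., -3:] = obs[1:, ..., :3]   # shared frames
rebuilt[-1, ..., -3:] = obs[-1, ..., -3:]   # last step has no successor to copy from
assert np.array_equal(rebuilt, obs)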
@@ -220,15 +246,28 @@ def build_trainer(name,
             policy.set_optimizer_state(state["optimizer_state"])
             policy.set_aux_optimizer_state(state["aux_optimizer_state"])
             policy.set_custom_state_vars(state["custom_state_vars"])
-            replay_buffer = state.get("replay_buffer", None)
-            if replay_buffer is not None:
-                if isinstance(replay_buffer, np.ndarray):
-                    policy.exp_replay = replay_buffer
-                else:
-                    buffshape, buffdtype = state["buffer_info"]
-                    policy.exp_replay = np.array(np.frombuffer(decompress(replay_buffer),
-                                                               buffdtype).reshape(buffshape))
+            ## Ugly hack to save replay buffer because organizers taking forever to give fix for spot instances
+            buffer_saved = state.get("buffer_saved", -1)
+            policy.save_success = buffer_saved
+            if buffer_saved == 1:
+                policy.exp_replay = state["replay_buffer"]
+            elif buffer_saved > 1:
+                non_eq, images_non_eq, images_last, images_first = state["matched_frame_data"]
+                policy.exp_replay[non_eq] = images_non_eq
+                policy.exp_replay[:,-1,...,-3:] = images_last
+                policy.exp_replay[:,0,...,:3] = images_first
+                if buffer_saved == 2:
+                    policy.exp_replay[:,1:,...,:3] = state["sliced_buffer"]
+                elif buffer_saved == 3:
+                    ts = policy.exp_replay[:,1:,...,:3].shape
+                    dt = policy.exp_replay.dtype
+                    decomp = decompress(state["compressed_buffer"])
+                    policy.exp_replay[:,1:,...,:3] = np.array(np.frombuffer(decomp, dtype=dt).reshape(ts))
+            if buffer_saved > 0:
+                policy.vtarg_replay = state["vtarg_replay"]
+                policy.retune_selector = state["retune_selector"]
             if self.train_exec_impl:
                 self.train_exec_impl.shared_metrics.get().restore(
                     state["train_exec_impl"])
@@ -2,8 +2,8 @@ import logging
 from ray.rllib.agents import with_common_config
 from .custom_torch_ppg import CustomTorchPolicy
-from ray.rllib.agents.trainer_template import build_trainer
-# from .custom_trainer_template import build_trainer
+# from ray.rllib.agents.trainer_template import build_trainer
+from .custom_trainer_template import build_trainer
 logger = logging.getLogger(__name__)
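Swapping these imports routes the trainer through the local custom_trainer_template, which is where the __getstate__/__setstate__ hack above lives, instead of RLlib's stock template. For orientation, a hedged sketch of how such a template is typically wired up inside this package; the trainer name and the omitted keyword arguments are illustrative, not taken from this repo, and the call mirrors the RLlib 0.8-era build_trainer signature.

from .custom_torch_ppg import CustomTorchPolicy
from .custom_trainer_template import build_trainer

# Illustrative only: the real call would pass its own name, default config,
# and any mixins alongside the default policy.
PPGTrainer = build_trainer(
    name="CustomTorchPPGTrainer",
    default_policy=CustomTorchPolicy,
)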
@@ -9,7 +9,7 @@ procgen-ppo:
 # === Settings for Checkpoints ===
-checkpoint_freq: 100
+checkpoint_freq: 25
 checkpoint_at_end: True
 keep_checkpoints_num: 5
@@ -46,18 +46,18 @@ procgen-ppo:
 no_done_at_end: False
 # Custom switches
-skips: 6
-n_pi: 2
-num_retunes: 20
+skips: 0
+n_pi: 10
+num_retunes: 100
 retune_epochs: 6
 standardize_rewards: True
 aux_mbsize: 4
 augment_buffer: False
-scale_reward: 1.0
+scale_reward: 0.6
 adaptive_gamma: False
 final_lr: 2.0e-4
-lr_schedule: 'linear'
+lr_schedule: 'None'
 final_entropy_coeff: 0.002
 entropy_schedule: False
@@ -87,7 +87,7 @@ procgen-ppo:
 init_normed: True
 use_layernorm: False
-num_workers: 4
+num_workers: 7
 num_envs_per_worker: 16
 rollout_fragment_length: 256