Commit 7106fa78 authored by Dipam Chakraborty

init for training process only

parent 41bef706
......@@ -22,6 +22,8 @@ class CustomTorchPolicy(TorchPolicy):
def __init__(self, observation_space, action_space, config):
self.config = config
self.action_space = action_space
self.observation_space = observation_space
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dist_class, logit_dim = ModelCatalog.get_action_dist(
......@@ -44,8 +46,12 @@ class CustomTorchPolicy(TorchPolicy):
loss=None,
action_distribution_class=dist_class,
)
self.framework = "torch"
def init_training(self):
""" Init once only for the policy - Surely there should be a bette way to do this """
aux_params = set(self.model.aux_vf.parameters())
value_params = set(self.model.value_fc.parameters())
network_params = set(self.model.parameters())
......@@ -80,7 +86,7 @@ class CustomTorchPolicy(TorchPolicy):
print("WARNING: MEMORY LIMITED BATCHING NOT SET PROPERLY")
print("#################################################")
replay_shape = (n_pi, nsteps, nenvs)
self.retune_selector = RetuneSelector(nenvs, observation_space, action_space, replay_shape,
self.retune_selector = RetuneSelector(nenvs, self.observation_space, self.action_space, replay_shape,
skips = self.config['skips'],
n_pi = n_pi,
num_retunes = self.config['num_retunes'],
......@@ -93,11 +99,11 @@ class CustomTorchPolicy(TorchPolicy):
self.gamma = self.config['gamma']
self.adaptive_discount_tuner = AdaptiveDiscountTuner(self.gamma, momentum=0.98, eplenmult=3)
self.lr = config['lr']
self.ent_coef = config['entropy_coeff']
self.lr = self.config['lr']
self.ent_coef = self.config['entropy_coeff']
self.last_dones = np.zeros((nw * self.config['num_envs_per_worker'],))
self.make_distr = dist_build(action_space)
self.make_distr = dist_build(self.action_space)
self.retunes_completed = 0
self.amp_scaler = GradScaler()
......
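
The policy change in this file moves all one-time training setup out of __init__ and into init_training(), which the trainer calls exactly once after its own construction finishes: the PPG-style split of parameters into aux-value, value, and policy groups with separate Adam optimizers, plus the AMP GradScaler. Below is a minimal, self-contained sketch of that pattern; the tiny ModuleDict model and the learning rates are placeholders, not the repo's actual network or settings.

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler

class SketchPolicy:
    def __init__(self):
        # Keep construction cheap: no optimizers or training buffers yet.
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model = nn.ModuleDict({
            "body": nn.Linear(64, 64),
            "pi": nn.Linear(64, 15),
            "value_fc": nn.Linear(64, 1),
            "aux_vf": nn.Linear(64, 1),
        }).to(self.device)

    def init_training(self):
        # Partition parameters the way the hunk above does: the aux value head
        # and the main value head each get their own Adam, everything else goes
        # to the policy optimizer. Learning rates here are made up.
        aux_params = set(self.model["aux_vf"].parameters())
        value_params = set(self.model["value_fc"].parameters())
        network_params = set(self.model.parameters())
        policy_params = network_params - aux_params - value_params
        self.optimizer = torch.optim.Adam(policy_params, lr=5e-4)
        self.aux_optimizer = torch.optim.Adam(aux_params, lr=5e-4)
        self.value_optimizer = torch.optim.Adam(value_params, lr=1e-3)
        self.amp_scaler = GradScaler(enabled=torch.cuda.is_available())

policy = SketchPolicy()
policy.init_training()  # the trainer triggers this once, right after after_init

This matches the commit's intent, "init for training process only": policy construction stays cheap, and everything that exists purely for training is created in one place.
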
......@@ -139,6 +139,9 @@ def build_trainer(name,
**optimizer_config)
if after_init:
after_init(self)
policy = Trainer.get_policy(self)
policy.init_training()
@override(Trainer)
def _train(self):
......@@ -192,11 +195,14 @@ def build_trainer(name,
state = Trainer.__getstate__(self)
state["trainer_state"] = self.state.copy()
policy = Trainer.get_policy(self)
state["custom_state_vars"] = policy.get_custom_state_vars()
state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
state["aux_optimizer_state"] = {k: v for k, v in policy.aux_optimizer.state_dict().items()}
state["value_optimizer_state"] = {k: v for k, v in policy.value_optimizer.state_dict().items()}
state["amp_scaler_state"] = {k: v for k, v in policy.amp_scaler.state_dict().items()}
try:
state["custom_state_vars"] = policy.get_custom_state_vars()
state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
state["aux_optimizer_state"] = {k: v for k, v in policy.aux_optimizer.state_dict().items()}
state["value_optimizer_state"] = {k: v for k, v in policy.value_optimizer.state_dict().items()}
state["amp_scaler_state"] = {k: v for k, v in policy.amp_scaler.state_dict().items()}
except Exception:
print("################# WARNING: SAVING STATE VARS AND OPTIMIZER FAILED ################")
if self.train_exec_impl:
state["train_exec_impl"] = (
......@@ -207,9 +213,12 @@ def build_trainer(name,
Trainer.__setstate__(self, state)
policy = Trainer.get_policy(self)
self.state = state["trainer_state"].copy()
policy.set_optimizer_state(state["optimizer_state"], state["aux_optimizer_state"],
state["value_optimizer_state"], state["amp_scaler_state"])
policy.set_custom_state_vars(state["custom_state_vars"])
try:
policy.set_optimizer_state(state["optimizer_state"], state["aux_optimizer_state"],
state["value_optimizer_state"], state["amp_scaler_state"])
policy.set_custom_state_vars(state["custom_state_vars"])
except Exception:
print("################# WARNING: LOADING STATE VARS AND OPTIMIZER FAILED ################")
if self.train_exec_impl:
self.train_exec_impl.shared_metrics.get().restore(
......
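
The trainer-template change in this file does two things: right after the after_init hook it fetches the policy and runs its one-time init_training(), and it wraps checkpoint save/restore of the custom state vars, the three optimizers, and the AMP scaler in try/except so that a failure only prints a warning instead of invalidating the whole checkpoint. A hedged sketch of the save/restore half, reduced to a single optimizer and scaler and using invented helper names rather than the template's actual __getstate__/__setstate__:

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler

def save_training_state(policy, state):
    # Best effort: a failure here must not take the whole checkpoint down.
    try:
        state["optimizer_state"] = policy.optimizer.state_dict()
        state["amp_scaler_state"] = policy.amp_scaler.state_dict()
    except Exception as exc:
        print(f"WARNING: saving optimizer/scaler state failed: {exc}")
    return state

def restore_training_state(policy, state):
    # Same idea on load: if the stored state is missing or incompatible with
    # the current model, keep the freshly initialized optimizers instead.
    try:
        policy.optimizer.load_state_dict(state["optimizer_state"])
        policy.amp_scaler.load_state_dict(state["amp_scaler_state"])
    except Exception as exc:
        print(f"WARNING: loading optimizer/scaler state failed: {exc}")

class _DummyPolicy:
    def __init__(self):
        self.optimizer = torch.optim.Adam(nn.Linear(4, 2).parameters(), lr=1e-3)
        self.amp_scaler = GradScaler(enabled=False)

p = _DummyPolicy()
restore_training_state(p, save_training_state(p, {}))
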
......@@ -21,6 +21,8 @@ class CustomTorchPolicy(TorchPolicy):
def __init__(self, observation_space, action_space, config):
self.config = config
self.action_space = action_space
self.observation_space = observation_space
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dist_class, logit_dim = ModelCatalog.get_action_dist(
......@@ -43,7 +45,11 @@ class CustomTorchPolicy(TorchPolicy):
loss=None,
action_distribution_class=dist_class,
)
self.framework = "torch"
def init_training(self):
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
self.max_reward = self.config['env_config']['return_max']
self.rewnorm = RewardNormalizer(cliprew=self.max_reward) ## TODO: Might need to go to custom state
......@@ -64,11 +70,11 @@ class CustomTorchPolicy(TorchPolicy):
print("#################################################")
print("WARNING: MEMORY LIMITED BATCHING NOT SET PROPERLY")
print("#################################################")
self.retune_selector = RetuneSelector(self.nbatch, observation_space, action_space,
self.retune_selector = RetuneSelector(self.nbatch, self.observation_space, self.action_space,
skips = self.config['retune_skips'],
replay_size = self.config['retune_replay_size'],
num_retunes = self.config['num_retunes'])
self.exp_replay = np.zeros((self.retune_selector.replay_size, *observation_space.shape), dtype=np.uint8)
self.exp_replay = np.empty((self.retune_selector.replay_size, *self.observation_space.shape), dtype=np.uint8)
self.target_timesteps = 8_000_000
self.buffer_time = 20 # TODO: Could try to do a median or mean time step check instead
self.max_time = 10000000000000 # ignore timekeeping because spot instances are messing it up
......@@ -76,8 +82,8 @@ class CustomTorchPolicy(TorchPolicy):
self.gamma = self.config['gamma']
self.adaptive_discount_tuner = AdaptiveDiscountTuner(self.gamma, momentum=0.98, eplenmult=3)
self.lr = config['lr']
self.ent_coef = config['entropy_coeff']
self.lr = self.config['lr']
self.ent_coef = self.config['entropy_coeff']
self.last_dones = np.zeros((nw * self.config['num_envs_per_worker'],))
self.save_success = 0
......@@ -270,7 +276,6 @@ class CustomTorchPolicy(TorchPolicy):
self.to_tensor(replay_pi[mbinds])]
self.tune_policy(apply_grad, *slices, 0.5)
self.exp_replay.fill(0)
self.retunes_completed += 1
self.retune_selector.retune_done()
......
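
This is the PPO-variant policy with the same treatment: the spaces are read from self.observation_space and self.action_space, the one-time setup moves into init_training(), and the uint8 experience-replay buffer used for retuning is preallocated with np.empty instead of np.zeros. A small sketch of that buffer and the shuffled-minibatch loop it feeds during a retune epoch; the 64x64x3 observation shape and the sizes are assumptions, not the repo's settings.

import numpy as np

replay_size, obs_shape, nbatch_train = 2_000, (64, 64, 3), 256  # assumed sizes

# np.empty skips the zero-fill pass over the whole buffer; every slot is
# written during rollout collection before it is ever read back.
exp_replay = np.empty((replay_size, *obs_shape), dtype=np.uint8)

def retune_epoch(tune_minibatch):
    # One pass over the replay buffer in shuffled minibatches, mirroring the
    # mbinds slicing in the retune loop above.
    inds = np.arange(replay_size)
    np.random.shuffle(inds)
    for start in range(0, replay_size, nbatch_train):
        mbinds = inds[start:start + nbatch_train]
        tune_minibatch(exp_replay[mbinds])

retune_epoch(lambda obs_batch: None)  # stand-in for tune_policy(...) on each slice
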
......@@ -139,6 +139,9 @@ def build_trainer(name,
**optimizer_config)
if after_init:
after_init(self)
policy = Trainer.get_policy(self)
policy.init_training()
@override(Trainer)
def _train(self):
......@@ -192,9 +195,12 @@ def build_trainer(name,
state = Trainer.__getstate__(self)
state["trainer_state"] = self.state.copy()
policy = Trainer.get_policy(self)
state["custom_state_vars"] = policy.get_custom_state_vars()
state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
state["amp_scaler_state"] = {k: v for k, v in policy.amp_scaler.state_dict().items()}
try:
state["custom_state_vars"] = policy.get_custom_state_vars()
state["optimizer_state"] = {k: v for k, v in policy.optimizer.state_dict().items()}
state["amp_scaler_state"] = {k: v for k, v in policy.amp_scaler.state_dict().items()}
except Exception:
print("################# WARNING: SAVING STATE VARS AND OPTIMIZER FAILED ################")
if self.train_exec_impl:
state["train_exec_impl"] = (
......@@ -205,8 +211,11 @@ def build_trainer(name,
Trainer.__setstate__(self, state)
policy = Trainer.get_policy(self)
self.state = state["trainer_state"].copy()
policy.set_optimizer_state(state["optimizer_state"], state["amp_scaler_state"])
policy.set_custom_state_vars(state["custom_state_vars"])
try:
policy.set_optimizer_state(state["optimizer_state"], state["amp_scaler_state"])
policy.set_custom_state_vars(state["custom_state_vars"])
except Exception:
print("################# WARNING: LOADING STATE VARS AND OPTIMIZER FAILED ################")
if self.train_exec_impl:
self.train_exec_impl.shared_metrics.get().restore(
......
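
The PPO-variant trainer template gets the same two additions as the one above: policy.init_training() is invoked once right after the after_init hook, and saving/loading of optimizer and scaler state becomes best-effort. A generic sketch of the trainer-side hook with placeholder class names (this is not RLlib's build_trainer/Trainer API, just the shape of the call):

class MinimalTrainer:
    def __init__(self, policy, after_init=None):
        self.policy = policy
        # ... workers and the policy optimizer config would be built here ...
        if after_init:
            after_init(self)
        # One-time training setup, exactly once per trainer construction.
        self.get_policy().init_training()

    def get_policy(self):
        return self.policy

class MinimalPolicy:
    def init_training(self):
        print("building optimizers and replay buffers once")

trainer = MinimalTrainer(MinimalPolicy(), after_init=lambda t: print("after_init hook ran"))
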
......@@ -8,7 +8,7 @@ procgen-ppo:
time_total_s: 7200
# === Settings for Checkpoints ===
checkpoint_freq: 1
checkpoint_freq: 100
checkpoint_at_end: True
keep_checkpoints_num: 5
......
......@@ -46,9 +46,9 @@ procgen-ppo:
no_done_at_end: False
# Custom switches
retune_skips: 100000
retune_skips: 50000
retune_replay_size: 200000
num_retunes: 23
num_retunes: 28
retune_epochs: 3
standardize_rewards: True
scale_reward: 1.0
......