Commit 86b99ebb authored by nilabha

include support for checkpointing

parent c6f4a5a5
Pipeline #4964 passed with stage in 23 minutes and 49 seconds
@@ -97,6 +97,7 @@ class ImitationAgent(PPOTrainer):
    @override(Trainer)
    def _init(self, config, env_creator):
        self.env = env_creator(config["env_config"])
        self.state = {}
        self._policy = ImitationTFPolicy
        action_space = self.env.action_space
        dist_class, logit_dim = ModelCatalog.get_action_dist(
@@ -106,6 +107,72 @@ class ImitationAgent(PPOTrainer):
        self.execution_plan = default_execution_plan
        self.train_exec_impl = self.execution_plan(self.workers, config)

    def eval(self):
        import tensorflow as tf
        policy = self.get_policy()
        steps = 0
        all_scores = 0
        all_completion = 0
        eval_episodes = 2
        for _ in range(eval_episodes):
            env = self.env._env.rail_env
            obs = self.env.reset()
            num_outputs = env.action_space[0]
            n_agents = env.get_num_agents()
            # TODO: Update max_steps as per latest version
            # https://gitlab.aicrowd.com/flatland/flatland-examples/blob/master/reinforcement_learning/multi_agent_training.py
            # max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities))) - 1
            max_steps = int(4 * 2 * (20 + env.height + env.width))
            episode_steps = 0
            episode_max_steps = 0
            episode_num_agents = 0
            episode_score = 0
            episode_done_agents = 0
            # obs = self.env.reset()
            done = {}
            done["__all__"] = False
            action_dict = {i: 2 for i in range(n_agents)}
            # TODO: Support for batch update (see the sketch after this hunk)
            # batch_size = 2
            # logits, _ = policy.model.forward({"obs": np.vstack([obs[a],obs[a]])}, [], None)
            # while not done["__all__"]:
            for step in range(max_steps):
                for a in range(n_agents):
                    if not done.get(a) and obs.get(a) is not None:
                        input_dict = {"obs": np.expand_dims(obs[a], 0)}
                        input_dict['obs_flat'] = input_dict['obs']
                        logits, _ = policy.model.forward(input_dict, [], None)
                        model_logits = tf.squeeze(logits)
                        action_dict[a] = tf.math.argmax(model_logits).numpy()

                obs, all_rewards, done, info = self.env.step(action_dict)
                steps += 1
                # super()._train()

                for agent, agent_info in info.items():
                    if episode_max_steps == 0:
                        episode_max_steps = agent_info["max_episode_steps"]
                        episode_num_agents = agent_info["num_agents"]
                    episode_steps = max(episode_steps, agent_info["agent_step"])
                    episode_score += agent_info["agent_score"]
                    if agent_info["agent_done"]:
                        episode_done_agents += 1

                if done["__all__"]:
                    all_scores += episode_score
                    all_completion += float(episode_done_agents) / n_agents
                    break

        return {
            "episode_reward_mean": all_scores / eval_episodes,
            "episode_completion_mean": all_completion / eval_episodes,
            "timesteps_this_iter": steps,
        }

    @override(Trainer)
    def _train(self):
        import tensorflow as tf
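
The "Support for batch update" TODO inside eval() points at stacking all live agents' observations into a single forward pass instead of looping per agent. A minimal sketch of that idea, reusing the obs / obs_flat input-dict convention from the code above (not part of this commit, and untested against this model):

    # Sketch only: one batched forward pass per environment step.
    live_agents = [a for a in range(n_agents) if not done.get(a) and obs.get(a) is not None]
    if live_agents:
        input_dict = {"obs": np.vstack([obs[a] for a in live_agents])}
        input_dict["obs_flat"] = input_dict["obs"]
        logits, _ = policy.model.forward(input_dict, [], None)
        greedy_actions = tf.argmax(logits, axis=-1).numpy()
        for idx, a in enumerate(live_agents):
            action_dict[a] = int(greedy_actions[idx])
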
@@ -231,8 +298,8 @@ if __name__ == "__main__":
        exp["config"])

    ray.init(num_cpus=1,num_gpus=0)
    trainer1 = ImitationAgent(_default_config,
    ray.init(num_cpus=3,num_gpus=0)
    imitation_trainer = ImitationAgent(_default_config,
        env="flatland_sparse",)
    # default_policy=ImitationPolicy,
    # get_policy_class=ImitationPolicy)
@@ -243,9 +310,19 @@ if __name__ == "__main__":
    # trainer = PPOTrainer(_default_config,
    #                      env="flatland_sparse",)
    result = trainer1.train()
    for i in range(10):
        result = imitation_trainer.train()
        if i % 5 == 0:
            eval_results = imitation_trainer.eval()
            print("Eval Results:", eval_results)
        checkpoint = imitation_trainer.save()
        # TODO: Loads weights but not optimizer state.
        # Could be done by overriding _save using model.save_weights(checkpoint);
        # also override _restore. Ideally use workers to save/load weights.
        # (A sketch follows this hunk.)
        # imitation_trainer.restore(checkpoint)
        print("checkpoint saved at", checkpoint)

    imitation_trainer.stop()

    # registry.register_trainable('ImitationPolicyTrainer',ImitationAgent)
    # ImitationPolicyTrainer = build_trainer(
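
The TODO above suggests persisting only the model weights by overriding the Trainable save/restore hooks. A rough sketch of what that could look like on ImitationAgent, assuming the policy's TFModelV2 exposes its underlying Keras model as base_model (true for RLlib's built-in fully connected network; a custom model may differ), and noting that optimizer state would still not be saved:

    @override(Trainer)
    def _save(self, checkpoint_dir):
        # Sketch only: store the policy network weights in an HDF5 file.
        import os
        checkpoint_path = os.path.join(checkpoint_dir, "imitation_model_weights.h5")
        self.get_policy().model.base_model.save_weights(checkpoint_path)
        return checkpoint_path

    @override(Trainer)
    def _restore(self, checkpoint_path):
        # Sketch only: reload the weights written by _save above.
        self.get_policy().model.base_model.load_weights(checkpoint_path)
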
@@ -95,6 +95,7 @@ class ImitationAgent(PPOTrainer):
    @override(Trainer)
    def _init(self, config, env_creator):
        self.env = env_creator(config["env_config"])
        self.state = {}
        self._policy = ImitationTFPolicy
        action_space = self.env.action_space
        dist_class, logit_dim = ModelCatalog.get_action_dist(
@@ -218,7 +219,7 @@ if __name__ == "__main__":
    ppo_trainer = PPOTrainer(_default_config,
        env="flatland_sparse",)

    for i in range(5):
    for i in range(10):
        print("== Iteration", i, "==")

        trainer_type = np.random.binomial(size=1, n=1, p=0.5)[0]
@@ -227,12 +228,16 @@ if __name__ == "__main__":
            # improve the Imitation policy
            print("-- Imitation --")
            print(pretty_print(imitation_trainer.train()))
            checkpoint = imitation_trainer.save()
            print("checkpoint saved at", checkpoint)
            ppo_trainer.set_weights(imitation_trainer.get_weights())
        else:
            # improve the PPO policy
            print("-- PPO --")
            print(pretty_print(ppo_trainer.train()))
            checkpoint = ppo_trainer.save()
            print("checkpoint saved at", checkpoint)
            imitation_trainer.set_weights(ppo_trainer.get_weights())

    print("Done: OK")
\ No newline at end of file
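
Since both trainers now write checkpoints, an interrupted alternating run could be resumed with the standard restore() call before re-entering the loop. A sketch, where last_imitation_checkpoint and last_ppo_checkpoint are hypothetical variables holding the paths returned by the save() calls above:

    # Sketch (not part of this commit): resume both trainers from their last checkpoints.
    if last_imitation_checkpoint is not None:
        imitation_trainer.restore(last_imitation_checkpoint)
    if last_ppo_checkpoint is not None:
        ppo_trainer.restore(last_ppo_checkpoint)
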