From 1cb2c042586553362f8d332ac7cb2c5817aac82e Mon Sep 17 00:00:00 2001 From: Guillaume Mollard <guillaume.mollard2@gmail.com> Date: Tue, 14 May 2019 16:42:43 +0200 Subject: [PATCH] change in wrapper should solve bug --- RailEnvRLLibWrapper.py | 16 ++++++++-------- train.py | 14 +++++++------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/RailEnvRLLibWrapper.py b/RailEnvRLLibWrapper.py index 1b537a6..fef562d 100644 --- a/RailEnvRLLibWrapper.py +++ b/RailEnvRLLibWrapper.py @@ -6,15 +6,15 @@ from flatland.envs.generators import random_rail_generator class RailEnvRLLibWrapper(RailEnv, MultiAgentEnv): - def __init__(self, - width, - height, - rail_generator=random_rail_generator(), - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2)): + def __init__(self, config): + # width, + # height, + # rail_generator=random_rail_generator(), + # number_of_agents=1, + # obs_builder_object=TreeObsForRailEnv(max_depth=2)): - super(RailEnvRLLibWrapper, self).__init__(width=width, height=height, rail_generator=rail_generator, - number_of_agents=number_of_agents, obs_builder_object=obs_builder_object) + super(RailEnvRLLibWrapper, self).__init__(width=config["width"], height=config["height"], rail_generator=config["rail_generator"], + number_of_agents=config["number_of_agents"]) def reset(self, regen_rail=True, replace_agents=True): self.agents_done = [] diff --git a/train.py b/train.py index 686ae5d..320e494 100644 --- a/train.py +++ b/train.py @@ -11,7 +11,7 @@ from flatland.envs.generators import complex_rail_generator import ray.rllib.agents.ppo.ppo as ppo import ray.rllib.agents.dqn.dqn as dqn -from ray.rllib.agents.ppo.ppo import PPOAgent +from ray.rllib.agents.ppo.ppo import PPOTrainer from ray.rllib.agents.dqn.dqn import DQNAgent from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph @@ -64,10 +64,10 @@ def train(config): rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), number_of_agents=1) """ - env = RailEnvRLLibWrapper(width=20, - height=20, - rail_generator=complex_rail_generator(nr_start_goal=5, min_dist=5, max_dist=99999, seed=0), - number_of_agents=5) + env_config = {"width": 20, + "height":20, + "rail_generator":complex_rail_generator(nr_start_goal=5, min_dist=5, max_dist=99999, seed=0), + "number_of_agents":5} """ env = RailEnv(width=20, height=20, @@ -77,7 +77,6 @@ def train(config): """ - register_env("railenv", lambda _: env) # if config['render']: # env_renderer = RenderTool(env, gl="QT") # plt.figure(figsize=(5,5)) @@ -105,9 +104,10 @@ def train(config): # agent_config["num_gpus_per_worker"] = 2.0 agent_config["num_cpus_for_driver"] = 5 agent_config["num_envs_per_worker"] = 15 + agent_config["env_config"] = env_config #agent_config["batch_mode"] = "complete_episodes" - ppo_trainer = PPOAgent(env=f"railenv", config=agent_config) + ppo_trainer = PPOTrainer(env=f"railenv", config=agent_config) for i in range(100000 + 2): print("== Iteration", i, "==") -- GitLab