From 1c1e6dc8706447fc9ab3bad1b4e2fd7ab01f87d8 Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume@iccluster091.iccluster.epfl.ch>
Date: Tue, 14 May 2019 16:35:29 +0200
Subject: [PATCH] bug with multiple environments

---
 RailEnvRLLibWrapper.py | 15 +++++++--------
 train.py               | 15 +++++++++++----
 2 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/RailEnvRLLibWrapper.py b/RailEnvRLLibWrapper.py
index 3649434..1b537a6 100644
--- a/RailEnvRLLibWrapper.py
+++ b/RailEnvRLLibWrapper.py
@@ -13,8 +13,8 @@ class RailEnvRLLibWrapper(RailEnv, MultiAgentEnv):
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2)):
 
-        super(RailEnvRLLibWrapper, self).__init__(width, height, rail_generator,
-                                                  number_of_agents, obs_builder_object)
+        super(RailEnvRLLibWrapper, self).__init__(width=width, height=height, rail_generator=rail_generator,
+                                                  number_of_agents=number_of_agents, obs_builder_object=obs_builder_object)
 
     def reset(self, regen_rail=True, replace_agents=True):
         self.agents_done = []
@@ -32,17 +32,16 @@ class RailEnvRLLibWrapper(RailEnv, MultiAgentEnv):
                 if agent != '__all__':
                     o[agent] = obs[agent]
                     r[agent] = rewards[agent]
-
+
                 d[agent] = dones[agent]
-        # obs.pop(agent_done)
-        # rewards.pop(agent_done)
-        # dones.pop(agent_done)
 
         for agent, done in dones.items():
             if done and agent != '__all__':
                 self.agents_done.append(agent)
-
+
+        #print(obs)
+        #return obs, rewards, dones, infos
         return o, r, d, infos
-
+
     def get_agent_handles(self):
         return super(RailEnvRLLibWrapper, self).get_agent_handles()

diff --git a/train.py b/train.py
index 71f214c..686ae5d 100644
--- a/train.py
+++ b/train.py
@@ -64,9 +64,9 @@ def train(config):
                   rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
                   number_of_agents=1)
     """
-    env = RailEnv(width=50,
-                  height=50,
-                  rail_generator=complex_rail_generator(nr_start_goal=50, min_dist=5, max_dist=99999, seed=0),
+    env = RailEnvRLLibWrapper(width=20,
+                  height=20,
+                  rail_generator=complex_rail_generator(nr_start_goal=5, min_dist=5, max_dist=99999, seed=0),
                   number_of_agents=5)
     """
     env = RailEnv(width=20,
@@ -94,11 +94,18 @@ def train(config):
         return f"ppo_policy"
 
     agent_config = ppo.DEFAULT_CONFIG.copy()
-    agent_config['model'] = {"fcnet_hiddens": [32, 32], "custom_preprocessor": "my_prep"}
+    agent_config['model'] = {"fcnet_hiddens": [32, 32]}#, "custom_preprocessor": "my_prep"}
    agent_config['multiagent'] = {"policy_graphs": policy_graphs,
                                   "policy_mapping_fn": policy_mapping_fn,
                                   "policies_to_train": list(policy_graphs.keys())}
     agent_config["horizon"] = 50
+    #agent_config["num_workers"] = 0
+    #agent_config["num_cpus_per_worker"] = 40
+    #agent_config["num_gpus"] = 2.0
+    # agent_config["num_gpus_per_worker"] = 2.0
+    agent_config["num_cpus_for_driver"] = 5
+    agent_config["num_envs_per_worker"] = 15
+    #agent_config["batch_mode"] = "complete_episodes"
 
     ppo_trainer = PPOAgent(env=f"railenv", config=agent_config)

-- 
GitLab
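
Note on the RailEnvRLLibWrapper.step() hunk above: RLlib's MultiAgentEnv contract expects an agent that has been reported done once to disappear from subsequent obs/reward/done dicts, while the '__all__' key must remain in the done dict so the episode end can be detected. Below is a minimal, self-contained sketch of that filtering pattern; the function name filter_step and the free-standing agents_done list are hypothetical (in the patch this logic lives in step() itself, with self.agents_done reset to [] in reset(), and the outer skip of already-done agents sits outside the visible hunk context):

    # Hypothetical sketch of the done-agent filtering the patched step() performs.
    def filter_step(obs, rewards, dones, infos, agents_done):
        o, r, d = {}, {}, {}
        for agent in dones:
            if agent in agents_done:
                continue  # already reported done in an earlier step, drop it
            if agent != '__all__':
                o[agent] = obs[agent]
                r[agent] = rewards[agent]
            # '__all__' stays in the done dict so RLlib sees the episode end
            d[agent] = dones[agent]
        # Remember newly finished agents so they are filtered out next step.
        for agent, done in dones.items():
            if done and agent != '__all__':
                agents_done.append(agent)
        return o, r, d, infos

The train.py hunk exercises this path by setting agent_config["num_envs_per_worker"] = 15, which makes each RLlib worker hold several env copies in parallel; that vectorized setup is presumably what surfaced the done-agent bookkeeping bug this commit addresses.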