Skip to content
Snippets Groups Projects
Commit 1c1e6dc8 authored by Guillaume Mollard's avatar Guillaume Mollard
Browse files

Fix bug with multiple environments

parent c4df1ca0
No related branches found
No related tags found
No related merge requests found
......@@ -13,8 +13,8 @@ class RailEnvRLLibWrapper(RailEnv, MultiAgentEnv):
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2)):
super(RailEnvRLLibWrapper, self).__init__(width, height, rail_generator,
number_of_agents, obs_builder_object)
super(RailEnvRLLibWrapper, self).__init__(width=width, height=height, rail_generator=rail_generator,
number_of_agents=number_of_agents, obs_builder_object=obs_builder_object)
def reset(self, regen_rail=True, replace_agents=True):
self.agents_done = []
......@@ -32,17 +32,16 @@ class RailEnvRLLibWrapper(RailEnv, MultiAgentEnv):
if agent != '__all__':
o[agent] = obs[agent]
r[agent] = rewards[agent]
d[agent] = dones[agent]
# obs.pop(agent_done)
# rewards.pop(agent_done)
# dones.pop(agent_done)
for agent, done in dones.items():
if done and agent != '__all__':
self.agents_done.append(agent)
#print(obs)
#return obs, rewards, dones, infos
return o, r, d, infos
def get_agent_handles(self):
    """Return the collection of agent handles managed by this environment.

    Thin pass-through: the wrapper adds no behavior of its own here and
    simply delegates to the inherited implementation.
    """
    # Explicit two-argument super() (pre-Python-3 style, consistent with the
    # __init__ above) delegates to the next class in the MRO — RailEnv,
    # since it is listed first in the class bases.
    return super(RailEnvRLLibWrapper, self).get_agent_handles()
......@@ -64,9 +64,9 @@ def train(config):
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=1)
"""
env = RailEnv(width=50,
height=50,
rail_generator=complex_rail_generator(nr_start_goal=50, min_dist=5, max_dist=99999, seed=0),
env = RailEnvRLLibWrapper(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=5, min_dist=5, max_dist=99999, seed=0),
number_of_agents=5)
"""
env = RailEnv(width=20,
......@@ -94,11 +94,18 @@ def train(config):
return f"ppo_policy"
agent_config = ppo.DEFAULT_CONFIG.copy()
agent_config['model'] = {"fcnet_hiddens": [32, 32], "custom_preprocessor": "my_prep"}
agent_config['model'] = {"fcnet_hiddens": [32, 32]}#, "custom_preprocessor": "my_prep"}
agent_config['multiagent'] = {"policy_graphs": policy_graphs,
"policy_mapping_fn": policy_mapping_fn,
"policies_to_train": list(policy_graphs.keys())}
agent_config["horizon"] = 50
#agent_config["num_workers"] = 0
#agent_config["num_cpus_per_worker"] = 40
#agent_config["num_gpus"] = 2.0
# agent_config["num_gpus_per_worker"] = 2.0
agent_config["num_cpus_for_driver"] = 5
agent_config["num_envs_per_worker"] = 15
#agent_config["batch_mode"] = "complete_episodes"
ppo_trainer = PPOAgent(env=f"railenv", config=agent_config)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment