diff --git a/RailEnvRLLibWrapper.py b/RailEnvRLLibWrapper.py
index 36494342b5fb2ef2c23e083dd6727e3a4e38f01d..1b537a6ae324d1ba34c62894c8407ce4f4b3ba70 100644
--- a/RailEnvRLLibWrapper.py
+++ b/RailEnvRLLibWrapper.py
@@ -13,8 +13,8 @@ class RailEnvRLLibWrapper(RailEnv, MultiAgentEnv):
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2)):
 
-        super(RailEnvRLLibWrapper, self).__init__(width, height, rail_generator,
-                                                  number_of_agents, obs_builder_object)
+        super(RailEnvRLLibWrapper, self).__init__(width=width, height=height, rail_generator=rail_generator,
+                                                  number_of_agents=number_of_agents, obs_builder_object=obs_builder_object)
 
     def reset(self, regen_rail=True, replace_agents=True):
         self.agents_done = []
@@ -32,17 +32,16 @@ class RailEnvRLLibWrapper(RailEnv, MultiAgentEnv):
             if agent != '__all__':
                 o[agent] = obs[agent]
                 r[agent] = rewards[agent]
-
+                d[agent] = dones[agent]
-            # obs.pop(agent_done)
-            # rewards.pop(agent_done)
-            # dones.pop(agent_done)
 
         for agent, done in dones.items():
             if done and agent != '__all__':
                 self.agents_done.append(agent)
-
+
+        #print(obs)
+        #return obs, rewards, dones, infos
         return o, r, d, infos
-
+
     def get_agent_handles(self):
         return super(RailEnvRLLibWrapper, self).get_agent_handles()
diff --git a/train.py b/train.py
index 71f214c15f58f9daa30a1f6919c8fed22bc3c566..686ae5d0ea8c95e0b55dc663f6adcc90b99b72fe 100644
--- a/train.py
+++ b/train.py
@@ -64,9 +64,9 @@ def train(config):
                   rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
                   number_of_agents=1)
     """
-    env = RailEnv(width=50,
-                  height=50,
-                  rail_generator=complex_rail_generator(nr_start_goal=50, min_dist=5, max_dist=99999, seed=0),
+    env = RailEnvRLLibWrapper(width=20,
+                  height=20,
+                  rail_generator=complex_rail_generator(nr_start_goal=5, min_dist=5, max_dist=99999, seed=0),
                   number_of_agents=5)
     """
     env = RailEnv(width=20,
@@ -94,11 +94,18 @@ def train(config):
        return f"ppo_policy"

    agent_config = ppo.DEFAULT_CONFIG.copy()
-    agent_config['model'] = {"fcnet_hiddens": [32, 32], "custom_preprocessor": "my_prep"}
+    agent_config['model'] = {"fcnet_hiddens": [32, 32]}#, "custom_preprocessor": "my_prep"
     agent_config['multiagent'] = {"policy_graphs": policy_graphs,
                                   "policy_mapping_fn": policy_mapping_fn,
                                   "policies_to_train": list(policy_graphs.keys())}
     agent_config["horizon"] = 50
+    #agent_config["num_workers"] = 0
+    #agent_config["num_cpus_per_worker"] = 40
+    #agent_config["num_gpus"] = 2.0
+    # agent_config["num_gpus_per_worker"] = 2.0
+    agent_config["num_cpus_for_driver"] = 5
+    agent_config["num_envs_per_worker"] = 15
+    #agent_config["batch_mode"] = "complete_episodes"

     ppo_trainer = PPOAgent(env=f"railenv", config=agent_config)
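
Illustrative sketch (not part of the diff): the step() hunk above builds filtered per-agent dicts so that agents already recorded in self.agents_done no longer appear in the observations, rewards, and dones returned to RLlib, and it now also copies the per-agent done flags into the filtered dict. A minimal standalone version of that filtering pattern, using a hypothetical helper name and only the logic visible in the hunk, could look like:

```python
# Hypothetical, standalone illustration of the filtering done in
# RailEnvRLLibWrapper.step(): agents that finished in an earlier step
# are dropped from the dicts handed back to RLlib, and per-agent done
# flags are copied alongside the observations and rewards.
def filter_finished_agents(obs, rewards, dones, agents_done):
    o, r, d = {}, {}, {}
    for agent in obs:
        if agent == '__all__' or agent in agents_done:
            continue
        o[agent] = obs[agent]
        r[agent] = rewards[agent]
        d[agent] = dones[agent]
    # Record agents that just finished so they are skipped next step.
    for agent, done in dones.items():
        if done and agent != '__all__':
            agents_done.append(agent)
    return o, r, d
```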