diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py index eb9bbd6b4fb72321d0b745922799e227abd4184e..f82cd42d9bbd836b681ff284a82f357b2760bb0c 100644 --- a/RLLib_training/RailEnvRLLibWrapper.py +++ b/RLLib_training/RailEnvRLLibWrapper.py @@ -4,6 +4,7 @@ from ray.rllib.utils.seed import seed as set_seed from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator, random_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator class RailEnvRLLibWrapper(MultiAgentEnv): @@ -25,19 +26,25 @@ class RailEnvRLLibWrapper(MultiAgentEnv): min_dist=config['min_dist'], nr_extra=config['nr_extra'], seed=config['seed'] * (1 + vector_index)) + self.schedule_generator = complex_schedule_generator() elif config['rail_generator'] == "random_rail_generator": self.rail_generator = random_rail_generator() + self.schedule_generator = random_schedule_generator() elif config['rail_generator'] == "load_env": self.predefined_env = True self.rail_generator = random_rail_generator() + self.schedule_generator = random_schedule_generator() else: raise (ValueError, f'Unknown rail generator: {config["rail_generator"]}') set_seed(config['seed'] * (1 + vector_index)) self.env = RailEnv(width=config["width"], height=config["height"], number_of_agents=config["number_of_agents"], - obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator) + obs_builder_object=config['obs_builder'], + rail_generator=self.rail_generator, + schedule_generator=self.schedule_generator + ) if self.predefined_env: self.env.load_resource('torch_training.railway', 'complex_scene.pkl') diff --git a/scoring/utils/misc_utils.py b/scoring/utils/misc_utils.py index e6f9195dcbf4adf242e18adbe02c150ba46cd3f5..dee5f47f7f8f09f253dfc3f8e3d48931df94efe7 100644 --- a/scoring/utils/misc_utils.py +++ b/scoring/utils/misc_utils.py @@ -7,7 +7,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator, rail_from_file -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool # Time factor to test the max time allowed for an env. @@ -113,7 +113,7 @@ def create_testfiles(parameters, test_nr=0, nr_trials_per_test=100): rail_generator=complex_rail_generator(nr_start_goal=nr_paths, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=parameters[3]), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObsForRailEnv(max_depth=2), number_of_agents=parameters[2]) printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20) diff --git a/sequential_agent/run_test.py b/sequential_agent/run_test.py index 31b21e896aa99993a53aea6be4e27e3493ca8105..ecba34a67ac37e5b88b9f7fcac34ea455c690078 100644 --- a/sequential_agent/run_test.py +++ b/sequential_agent/run_test.py @@ -4,7 +4,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool from sequential_agent.simple_order_agent import OrderedAgent @@ -31,7 +31,7 @@ env = RailEnv(width=x_dim, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObsForRailEnv(max_depth=1, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) env.reset(True, True) diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index 6e62687ec23043765ae5e90af1895e1b3e153b49..66a37ad290ade376f682fcb2f40f1a02533537dc 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -41,7 +41,7 @@ env = RailEnv(width=x_dim, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=observation_helper, number_of_agents=n_agents) env.reset(True, True) diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 1e545ecc4fbf0e3af30414b7a84ac53912122b76..3a4c4318db250b681f1c3455e99b9ff75cefae07 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -16,7 +16,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator # Import Flatland/ Observations and Predictors -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from torch_training.dueling_double_dqn import Agent from utils.observation_utils import normalize_observation @@ -57,7 +57,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=observation_helper, number_of_agents=n_agents) env.reset(True, True) @@ -115,9 +115,9 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=observation_helper, - number_of_agents=n_agents) + number_of_agents=n_agents)f # Adjust the parameters according to the new env. max_steps = int((env.height + env.width)) diff --git a/torch_training/multi_agent_two_time_step_training.py b/torch_training/multi_agent_two_time_step_training.py index d0aafa3b6ec148f80a1ee8357407371d0a69dbd9..08cd84c379fe54cd4d6b71140a96623ebe2a8cbf 100644 --- a/torch_training/multi_agent_two_time_step_training.py +++ b/torch_training/multi_agent_two_time_step_training.py @@ -16,7 +16,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator # Import Flatland/ Observations and Predictors -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from torch_training.dueling_double_dqn import Agent from utils.observation_utils import norm_obs_clip, split_tree @@ -53,7 +53,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=observation_helper, number_of_agents=n_agents) env.reset(True, True) @@ -111,7 +111,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py index 3b48d6f5429f70fcd212846c05359261c4971e50..3882cc9de4aefd89e294be18b0b2c1eee905cfcc 100644 --- a/torch_training/render_agent_behavior.py +++ b/torch_training/render_agent_behavior.py @@ -10,7 +10,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool from torch_training.dueling_double_dqn import Agent from utils.observation_utils import norm_obs_clip, split_tree @@ -38,7 +38,7 @@ env = RailEnv(width=x_dim, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) env.reset(True, True) diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 623beb5b77083df36411bf26d18538267d50fe22..fb03432fea29c0b4b16a02a90cf5fe509f873f7d 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -11,7 +11,7 @@ from dueling_double_dqn import Agent from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool from utils.observation_utils import norm_obs_clip, split_tree @@ -45,7 +45,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=observation_builder, number_of_agents=n_agents) env.reset(True, True) diff --git a/utils/misc_utils.py b/utils/misc_utils.py index 19c704bc19f302b1e6a74ef93b1715e154dff544..09b315cf2a851eadf84b1a31a0d773e46549673f 100644 --- a/utils/misc_utils.py +++ b/utils/misc_utils.py @@ -8,7 +8,7 @@ from line_profiler import LineProfiler from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from utils.observation_utils import norm_obs_clip, split_tree @@ -89,7 +89,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3): rail_generator=complex_rail_generator(nr_start_goal=nr_paths, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=parameters[3]), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=GlobalObsForRailEnv(), number_of_agents=parameters[2]) max_steps = int(3 * (env.height + env.width))