From 3e8a81946f099776f2e7e4df8493be07920129cc Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Wed, 28 Aug 2019 13:41:20 +0200 Subject: [PATCH] merge #141 renamed agent_generator* to schedule_generator* --- RLLib_training/RailEnvRLLibWrapper.py | 9 ++++++++- scoring/utils/misc_utils.py | 4 ++-- sequential_agent/run_test.py | 4 ++-- torch_training/multi_agent_inference.py | 2 +- torch_training/multi_agent_training.py | 8 ++++---- torch_training/multi_agent_two_time_step_training.py | 6 +++--- torch_training/render_agent_behavior.py | 4 ++-- torch_training/training_navigation.py | 4 ++-- utils/misc_utils.py | 4 ++-- 9 files changed, 26 insertions(+), 19 deletions(-) diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py index eb9bbd6..f82cd42 100644 --- a/RLLib_training/RailEnvRLLibWrapper.py +++ b/RLLib_training/RailEnvRLLibWrapper.py @@ -4,6 +4,7 @@ from ray.rllib.utils.seed import seed as set_seed from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator, random_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator class RailEnvRLLibWrapper(MultiAgentEnv): @@ -25,19 +26,25 @@ class RailEnvRLLibWrapper(MultiAgentEnv): min_dist=config['min_dist'], nr_extra=config['nr_extra'], seed=config['seed'] * (1 + vector_index)) + self.schedule_generator = complex_schedule_generator() elif config['rail_generator'] == "random_rail_generator": self.rail_generator = random_rail_generator() + self.schedule_generator = random_schedule_generator() elif config['rail_generator'] == "load_env": self.predefined_env = True self.rail_generator = random_rail_generator() + self.schedule_generator = random_schedule_generator() else: raise (ValueError, f'Unknown rail generator: {config["rail_generator"]}') set_seed(config['seed'] * (1 + vector_index)) self.env = RailEnv(width=config["width"], height=config["height"], number_of_agents=config["number_of_agents"], - obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator) + obs_builder_object=config['obs_builder'], + rail_generator=self.rail_generator, + schedule_generator=self.schedule_generator + ) if self.predefined_env: self.env.load_resource('torch_training.railway', 'complex_scene.pkl') diff --git a/scoring/utils/misc_utils.py b/scoring/utils/misc_utils.py index e6f9195..dee5f47 100644 --- a/scoring/utils/misc_utils.py +++ b/scoring/utils/misc_utils.py @@ -7,7 +7,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator, rail_from_file -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool # Time factor to test the max time allowed for an env. @@ -113,7 +113,7 @@ def create_testfiles(parameters, test_nr=0, nr_trials_per_test=100): rail_generator=complex_rail_generator(nr_start_goal=nr_paths, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=parameters[3]), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObsForRailEnv(max_depth=2), number_of_agents=parameters[2]) printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20) diff --git a/sequential_agent/run_test.py b/sequential_agent/run_test.py index 31b21e8..ecba34a 100644 --- a/sequential_agent/run_test.py +++ b/sequential_agent/run_test.py @@ -4,7 +4,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool from sequential_agent.simple_order_agent import OrderedAgent @@ -31,7 +31,7 @@ env = RailEnv(width=x_dim, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObsForRailEnv(max_depth=1, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) env.reset(True, True) diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index 6e62687..66a37ad 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -41,7 +41,7 @@ env = RailEnv(width=x_dim, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=observation_helper, number_of_agents=n_agents) env.reset(True, True) diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 1e545ec..3a4c431 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -16,7 +16,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator # Import Flatland/ Observations and Predictors -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from torch_training.dueling_double_dqn import Agent from utils.observation_utils import normalize_observation @@ -57,7 +57,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=observation_helper, number_of_agents=n_agents) env.reset(True, True) @@ -115,9 +115,9 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=observation_helper, - number_of_agents=n_agents) + number_of_agents=n_agents)f # Adjust the parameters according to the new env. max_steps = int((env.height + env.width)) diff --git a/torch_training/multi_agent_two_time_step_training.py b/torch_training/multi_agent_two_time_step_training.py index d0aafa3..08cd84c 100644 --- a/torch_training/multi_agent_two_time_step_training.py +++ b/torch_training/multi_agent_two_time_step_training.py @@ -16,7 +16,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator # Import Flatland/ Observations and Predictors -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from torch_training.dueling_double_dqn import Agent from utils.observation_utils import norm_obs_clip, split_tree @@ -53,7 +53,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=observation_helper, number_of_agents=n_agents) env.reset(True, True) @@ -111,7 +111,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py index 3b48d6f..3882cc9 100644 --- a/torch_training/render_agent_behavior.py +++ b/torch_training/render_agent_behavior.py @@ -10,7 +10,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool from torch_training.dueling_double_dqn import Agent from utils.observation_utils import norm_obs_clip, split_tree @@ -38,7 +38,7 @@ env = RailEnv(width=x_dim, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) env.reset(True, True) diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 623beb5..fb03432 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -11,7 +11,7 @@ from dueling_double_dqn import Agent from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool from utils.observation_utils import norm_obs_clip, split_tree @@ -45,7 +45,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=observation_builder, number_of_agents=n_agents) env.reset(True, True) diff --git a/utils/misc_utils.py b/utils/misc_utils.py index 19c704b..09b315c 100644 --- a/utils/misc_utils.py +++ b/utils/misc_utils.py @@ -8,7 +8,7 @@ from line_profiler import LineProfiler from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from utils.observation_utils import norm_obs_clip, split_tree @@ -89,7 +89,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3): rail_generator=complex_rail_generator(nr_start_goal=nr_paths, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=parameters[3]), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=GlobalObsForRailEnv(), number_of_agents=parameters[2]) max_steps = int(3 * (env.height + env.width)) -- GitLab