diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py index 989800044250d60b68c68b5d2e702b5625964024..f82cd42d9bbd836b681ff284a82f357b2760bb0c 100644 --- a/RLLib_training/RailEnvRLLibWrapper.py +++ b/RLLib_training/RailEnvRLLibWrapper.py @@ -2,8 +2,9 @@ import numpy as np from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.utils.seed import seed as set_seed -from flatland.envs.generators import complex_rail_generator, random_rail_generator from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator, random_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator class RailEnvRLLibWrapper(MultiAgentEnv): @@ -25,19 +26,25 @@ class RailEnvRLLibWrapper(MultiAgentEnv): min_dist=config['min_dist'], nr_extra=config['nr_extra'], seed=config['seed'] * (1 + vector_index)) + self.schedule_generator = complex_schedule_generator() elif config['rail_generator'] == "random_rail_generator": self.rail_generator = random_rail_generator() + self.schedule_generator = random_schedule_generator() elif config['rail_generator'] == "load_env": self.predefined_env = True self.rail_generator = random_rail_generator() + self.schedule_generator = random_schedule_generator() else: raise (ValueError, f'Unknown rail generator: {config["rail_generator"]}') set_seed(config['seed'] * (1 + vector_index)) self.env = RailEnv(width=config["width"], height=config["height"], number_of_agents=config["number_of_agents"], - obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator) + obs_builder_object=config['obs_builder'], + rail_generator=self.rail_generator, + schedule_generator=self.schedule_generator + ) if self.predefined_env: self.env.load_resource('torch_training.railway', 'complex_scene.pkl') diff --git a/scoring/utils/misc_utils.py b/scoring/utils/misc_utils.py index b15476dda523a8663f712daf096803fdc14218df..dee5f47f7f8f09f253dfc3f8e3d48931df94efe7 100644 --- a/scoring/utils/misc_utils.py +++ b/scoring/utils/misc_utils.py @@ -1,19 +1,19 @@ import random import time -from collections import deque import numpy as np -from flatland.envs.generators import complex_rail_generator, rail_from_file -from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv + +from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv -from flatland.utils.rendertools import RenderTool from flatland.envs.rail_env import RailEnv - -from utils.observation_utils import norm_obs_clip, split_tree, max_lt +from flatland.envs.rail_generators import complex_rail_generator, rail_from_file +from flatland.envs.schedule_generators import complex_schedule_generator +from flatland.utils.rendertools import RenderTool # Time factor to test the max time allowed for an env. max_time_factor = 1 + def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='*'): """ Call in a loop to create terminal progress bar @@ -54,7 +54,6 @@ def run_test(parameters, agent, observation_builder=None, observation_wrapper=No random.seed(parameters[3]) np.random.seed(parameters[3]) - printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20) for trial in range(nr_trials_per_test): # Reset the env @@ -73,7 +72,6 @@ def run_test(parameters, agent, observation_builder=None, observation_wrapper=No for a in range(env.get_num_agents()): obs[a] = observation_wrapper(obs[a]) - # Run episode trial_score = 0 max_steps = int(max_time_factor * (env.height + env.width)) @@ -115,6 +113,7 @@ def create_testfiles(parameters, test_nr=0, nr_trials_per_test=100): rail_generator=complex_rail_generator(nr_start_goal=nr_paths, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=parameters[3]), + schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObsForRailEnv(max_depth=2), number_of_agents=parameters[2]) printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20) @@ -151,6 +150,7 @@ def render_test(parameters, test_nr=0, nr_examples=5): env_renderer.close_window() return + def run_test_sequential(parameters, agent, test_nr=0, tree_depth=3): # Parameter initialization features_per_node = 9 @@ -168,7 +168,6 @@ def run_test_sequential(parameters, agent, test_nr=0, tree_depth=3): random.seed(parameters[3]) np.random.seed(parameters[3]) - printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20) for trial in range(nr_trials_per_test): # Reset the env @@ -177,7 +176,8 @@ def run_test_sequential(parameters, agent, test_nr=0, tree_depth=3): env = RailEnv(width=3, height=3, rail_generator=rail_from_file(file_name), - obs_builder_object=TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv()), + obs_builder_object=TreeObsForRailEnv(max_depth=tree_depth, + predictor=ShortestPathPredictorForRailEnv()), number_of_agents=1, ) @@ -185,7 +185,7 @@ def run_test_sequential(parameters, agent, test_nr=0, tree_depth=3): done = env.dones # Run episode trial_score = 0 - max_steps = int(max_time_factor* (env.height + env.width)) + max_steps = int(max_time_factor * (env.height + env.width)) for step in range(max_steps): # Action diff --git a/sequential_agent/run_test.py b/sequential_agent/run_test.py index d0b9ce70f50465a58885a9b1feb754791bb49f34..ecba34a67ac37e5b88b9f7fcac34ea455c690078 100644 --- a/sequential_agent/run_test.py +++ b/sequential_agent/run_test.py @@ -1,10 +1,12 @@ -from sequential_agent.simple_order_agent import OrderedAgent -from flatland.envs.generators import rail_from_file, complex_rail_generator +import numpy as np + from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool -import numpy as np +from sequential_agent.simple_order_agent import OrderedAgent np.random.seed(2) """ @@ -29,6 +31,7 @@ env = RailEnv(width=x_dim, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObsForRailEnv(max_depth=1, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) env.reset(True, True) diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index b310e95f19e6284cfd6fed6f195d9807551875b8..66a37ad290ade376f682fcb2f40f1a02533537dc 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -3,14 +3,14 @@ from collections import deque import numpy as np import torch -from flatland.envs.generators import rail_from_file -from flatland.envs.rail_env import RailEnv -from flatland.utils.rendertools import RenderTool from importlib_resources import path from observation_builders.observations import TreeObsForRailEnv from predictors.predictions import ShortestPathPredictorForRailEnv import torch_training.Nets +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_file +from flatland.utils.rendertools import RenderTool from torch_training.dueling_double_dqn import Agent from utils.observation_utils import normalize_observation @@ -41,6 +41,7 @@ env = RailEnv(width=x_dim, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), obs_builder_object=observation_helper, number_of_agents=n_agents) env.reset(True, True) diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index c3c8c2b6b49dfeee1eceafa52d2bc0b4f1ff366e..7da12e54a55b9afe802651475e43234c55fba12c 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -7,15 +7,16 @@ from collections import deque import matplotlib.pyplot as plt import numpy as np import torch -# Import Flatland/ Observations and Predictors -from flatland.envs.generators import complex_rail_generator -from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.predictions import ShortestPathPredictorForRailEnv -from flatland.envs.rail_env import RailEnv from importlib_resources import path # Import Torch and utility functions to normalize observation import torch_training.Nets +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +# Import Flatland/ Observations and Predictors +from flatland.envs.schedule_generators import complex_schedule_generator from torch_training.dueling_double_dqn import Agent from utils.observation_utils import normalize_observation @@ -56,6 +57,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), obs_builder_object=observation_helper, number_of_agents=n_agents) env.reset(True, True) @@ -113,6 +115,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), obs_builder_object=observation_helper, number_of_agents=n_agents) diff --git a/torch_training/multi_agent_two_time_step_training.py b/torch_training/multi_agent_two_time_step_training.py index 1c3225c24dba4c52756b42f41e53656b1f698be3..08cd84c379fe54cd4d6b71140a96623ebe2a8cbf 100644 --- a/torch_training/multi_agent_two_time_step_training.py +++ b/torch_training/multi_agent_two_time_step_training.py @@ -7,15 +7,16 @@ from collections import deque import matplotlib.pyplot as plt import numpy as np import torch -# Import Flatland/ Observations and Predictors -from flatland.envs.generators import complex_rail_generator -from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.predictions import ShortestPathPredictorForRailEnv -from flatland.envs.rail_env import RailEnv from importlib_resources import path # Import Torch and utility functions to normalize observation import torch_training.Nets +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +# Import Flatland/ Observations and Predictors +from flatland.envs.schedule_generators import complex_schedule_generator from torch_training.dueling_double_dqn import Agent from utils.observation_utils import norm_obs_clip, split_tree @@ -52,6 +53,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), obs_builder_object=observation_helper, number_of_agents=n_agents) env.reset(True, True) @@ -109,6 +111,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py index 489501ae8ca43df9ca94a86a837e274181b207ee..3882cc9de4aefd89e294be18b0b2c1eee905cfcc 100644 --- a/torch_training/render_agent_behavior.py +++ b/torch_training/render_agent_behavior.py @@ -3,14 +3,15 @@ from collections import deque import numpy as np import torch -from flatland.envs.generators import complex_rail_generator +from importlib_resources import path + +import torch_training.Nets from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool -from importlib_resources import path - -import torch_training.Nets from torch_training.dueling_double_dqn import Agent from utils.observation_utils import norm_obs_clip, split_tree @@ -37,6 +38,7 @@ env = RailEnv(width=x_dim, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) env.reset(True, True) diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 5cc7305563076e54e25a3fb2d890e276bcbf9a38..fb03432fea29c0b4b16a02a90cf5fe509f873f7d 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -7,11 +7,12 @@ import matplotlib.pyplot as plt import numpy as np import torch from dueling_double_dqn import Agent -from flatland.envs.generators import complex_rail_generator + from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool - from utils.observation_utils import norm_obs_clip, split_tree @@ -44,6 +45,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), obs_builder_object=observation_builder, number_of_agents=n_agents) env.reset(True, True) diff --git a/utils/misc_utils.py b/utils/misc_utils.py index 5b29c6b15f61b46062bac8d4fb6c4130fe61c6ec..09b315cf2a851eadf84b1a31a0d773e46549673f 100644 --- a/utils/misc_utils.py +++ b/utils/misc_utils.py @@ -3,15 +3,16 @@ import time from collections import deque import numpy as np -from flatland.envs.generators import complex_rail_generator -from flatland.envs.observations import GlobalObsForRailEnv -from flatland.envs.rail_env import RailEnv from line_profiler import LineProfiler +from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator from utils.observation_utils import norm_obs_clip, split_tree -def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '*'): +def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='*'): """ Call in a loop to create terminal progress bar @params: @@ -31,13 +32,14 @@ def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, if iteration == total: print('') + class RandomAgent: def __init__(self, state_size, action_size): self.state_size = state_size self.action_size = action_size - def act(self, state, eps = 0): + def act(self, state, eps=0): """ :param state: input is the observation of the agent :return: returns an action @@ -87,6 +89,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3): rail_generator=complex_rail_generator(nr_start_goal=nr_paths, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=parameters[3]), + schedule_generator=complex_schedule_generator(), obs_builder_object=GlobalObsForRailEnv(), number_of_agents=parameters[2]) max_steps = int(3 * (env.height + env.width))