diff --git a/README.md b/README.md index 02209ab4308f9dc65c84fc744de48f328ad5dfe4..f81370b950d1c3f9d40cbdde5d64f982df8de6a6 100644 --- a/README.md +++ b/README.md @@ -149,7 +149,7 @@ env = RailEnv(width=width, rail_generator=rail_generator, schedule_generator=schedule_generator, number_of_agents=nr_trains, - stochastic_data=stochastic_data, # Malfunction data generator + malfunction_generator=stochastic_data, # Malfunction data generator obs_builder_object=observation_builder, remove_agents_at_target=True # Removes agents at the end of their journey to make space for others ) diff --git a/examples/complex_rail_benchmark.py b/examples/complex_rail_benchmark.py index 87f248ee223b061fca8375d56072ed81d5ad338c..7a203baf87ebc07a3e5a1afad8606bdd98a8cc83 100644 --- a/examples/complex_rail_benchmark.py +++ b/examples/complex_rail_benchmark.py @@ -14,10 +14,8 @@ def run_benchmark(): np.random.seed(1) # Example generate a random rail - env = RailEnv(width=15, height=15, - rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12), - schedule_generator=complex_schedule_generator(), - number_of_agents=5) + env = RailEnv(width=15, height=15, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12), + schedule_generator=complex_schedule_generator(), number_of_agents=5) env.reset() n_trials = 20 diff --git a/examples/custom_observation_example_01_SimpleObs.py b/examples/custom_observation_example_01_SimpleObs.py index 600b8f0968fb2226ec92c224c32ffd7617138e5b..2cee8e867aaec93e1a42dba562be787ab03fa5cd 100644 --- a/examples/custom_observation_example_01_SimpleObs.py +++ b/examples/custom_observation_example_01_SimpleObs.py @@ -28,10 +28,7 @@ class SimpleObs(ObservationBuilder): def main(): - env = RailEnv(width=7, - height=7, - rail_generator=random_rail_generator(), - number_of_agents=3, + env = RailEnv(width=7, height=7, rail_generator=random_rail_generator(), number_of_agents=3, obs_builder_object=SimpleObs()) env.reset() diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py index b1729296a199fded38270a63a527de06d9e7b329..52a56b06dcdeb9bebeb9dbb70d5acc9ae0c350bb 100644 --- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py +++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py @@ -76,13 +76,10 @@ def main(args): else: assert False, "unhandled option" - env = RailEnv(width=7, - height=7, + env = RailEnv(width=7, height=7, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, - seed=1), - schedule_generator=complex_schedule_generator(), - number_of_agents=1, - obs_builder_object=SingleAgentNavigationObs()) + seed=1), schedule_generator=complex_schedule_generator(), + number_of_agents=1, obs_builder_object=SingleAgentNavigationObs()) obs, info = env.reset() env_renderer = RenderTool(env, gl="PILSVG") diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index 7af7499af8ffc5900a87d7e543310aab2a6df7f9..ac99835368fc4a5a709e72234396a286f1d622a7 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -122,13 +122,10 @@ def main(args): custom_obs_builder = ObservePredictions(custom_predictor) # Initiate Environment - env = RailEnv(width=10, - height=10, + env = RailEnv(width=10, height=10, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, - seed=1), - schedule_generator=complex_schedule_generator(), - number_of_agents=3, - obs_builder_object=custom_obs_builder) + seed=1), schedule_generator=complex_schedule_generator(), + number_of_agents=3, obs_builder_object=custom_obs_builder) obs, info = env.reset() env_renderer = RenderTool(env, gl="PILSVG") diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py index ed263ef93e80f4d4a04db240d5e21c6e855806f8..ceea22a94fd1c4803a73fe57230393116766c8bd 100644 --- a/examples/custom_railmap_example.py +++ b/examples/custom_railmap_example.py @@ -43,10 +43,7 @@ def custom_schedule_generator() -> ScheduleGenerator: return generator -env = RailEnv(width=6, - height=4, - rail_generator=custom_rail_generator(), - schedule_generator=custom_schedule_generator(), +env = RailEnv(width=6, height=4, rail_generator=custom_rail_generator(), schedule_generator=custom_schedule_generator(), number_of_agents=1) env.reset() diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 5ece03e9c56d672b76a453e0036f6b89c3a6ee77..5556a2a0ed9c5b67e2708c8bf222304603a131ad 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -30,21 +30,16 @@ speed_ration_map = {1.: 0.25, # Fast passenger train 1. / 3.: 0.25, # Slow commuter train 1. / 4.: 0.25} # Slow freight train -env = RailEnv(width=100, - height=100, - rail_generator=sparse_rail_generator(max_num_cities=30, - # Number of cities in map (where train stations are) - seed=14, # Random seed - grid_mode=False, - max_rails_between_cities=2, - max_rails_in_city=8, - ), - schedule_generator=sparse_schedule_generator(speed_ration_map), - number_of_agents=100, - stochastic_data=stochastic_data, # Malfunction data generator - obs_builder_object=GlobalObsForRailEnv(), - remove_agents_at_target=True - ) +env = RailEnv(width=100, height=100, rail_generator=sparse_rail_generator(max_num_cities=30, + # Number of cities in map (where train stations are) + seed=14, # Random seed + grid_mode=False, + max_rails_between_cities=2, + max_rails_in_city=8, + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=100, + obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data, + remove_agents_at_target=True) # RailEnv.DEPOT_POSITION = lambda agent, agent_handle : (agent_handle % env.height,0) diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py index 4cdf63b0b32d431efabc05ed4d593aa746a44b40..de7c77faebc1f9d8fb8bb9b18fec142601babcb7 100644 --- a/examples/introduction_flatland_2_1.py +++ b/examples/introduction_flatland_2_1.py @@ -72,15 +72,9 @@ observation_builder = GlobalObsForRailEnv() # observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) # Construct the enviornment with the given observation, generataors, predictors, and stochastic data -env = RailEnv(width=width, - height=height, - rail_generator=rail_generator, - schedule_generator=schedule_generator, - number_of_agents=nr_trains, - stochastic_data=stochastic_data, # Malfunction data generator - obs_builder_object=observation_builder, - remove_agents_at_target=True # Removes agents at the end of their journey to make space for others - ) +env = RailEnv(width=width, height=height, rail_generator=rail_generator, schedule_generator=schedule_generator, + number_of_agents=nr_trains, obs_builder_object=observation_builder, malfunction_generator=stochastic_data, + remove_agents_at_target=True) env.reset() # Initiate the renderer diff --git a/examples/simple_example_1.py b/examples/simple_example_1.py index 388128d0d246d73f0236b054a3228ec20c46864e..ba88930142f8344b13a3cae8de2178148e459998 100644 --- a/examples/simple_example_1.py +++ b/examples/simple_example_1.py @@ -9,10 +9,7 @@ specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)], [(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)], [(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]] -env = RailEnv(width=6, - height=4, - rail_generator=rail_from_manual_specifications_generator(specs), - number_of_agents=1) +env = RailEnv(width=6, height=4, rail_generator=rail_from_manual_specifications_generator(specs), number_of_agents=1) env.reset() diff --git a/examples/simple_example_2.py b/examples/simple_example_2.py index 34abee096a043b73f53de8eed42a2e2b73ec1cc5..f9659cbdb666a2a6bd94db8aa560ded40f69079a 100644 --- a/examples/simple_example_2.py +++ b/examples/simple_example_2.py @@ -23,8 +23,7 @@ transition_probability = [1.0, # empty cell - Case 0 1.0] # Case 10 - mirrored switch # Example generate a random rail -env = RailEnv(width=10, - height=10, +env = RailEnv(width=10, height=10, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), number_of_agents=3) diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index ccbe8682fe5c8744737a452c77257dd4570b6f75..82fca31943a6f60607dfbc8e6befdac185c3acc6 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -11,11 +11,9 @@ from flatland.utils.rendertools import RenderTool random.seed(1) np.random.seed(1) -env = RailEnv(width=7, - height=7, +env = RailEnv(width=7, height=7, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=1), - schedule_generator=complex_schedule_generator(), - number_of_agents=2, + schedule_generator=complex_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2)) env.reset() diff --git a/examples/training_example.py b/examples/training_example.py index 2ce2ad1a86dd85acf00926f413c941df84d973c7..5f8cbe4088b1358e13a323a7c665ac8ccf60f740 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -14,12 +14,9 @@ np.random.seed(1) TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) -env = RailEnv(width=20, - height=20, +env = RailEnv(width=20, height=20, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=2, min_dist=8, max_dist=99999, seed=1), - schedule_generator=complex_schedule_generator(), - obs_builder_object=TreeObservation, - number_of_agents=3) + schedule_generator=complex_schedule_generator(), number_of_agents=3, obs_builder_object=TreeObservation) env.reset() env_renderer = RenderTool(env, gl="PILSVG", ) diff --git a/flatland/cli.py b/flatland/cli.py index f544aabcb6d9e81ddf8703c59b8bf07324b3ce2c..cc7576d16a02b0d0268ecaab201921b5034d7ee0 100644 --- a/flatland/cli.py +++ b/flatland/cli.py @@ -18,16 +18,11 @@ from flatland.utils.rendertools import RenderTool @click.command() def demo(args=None): """Demo script to check installation""" - env = RailEnv( - width=15, - height=15, - rail_generator=complex_rail_generator( - nr_start_goal=10, - nr_extra=1, - min_dist=8, - max_dist=99999), - schedule_generator=complex_schedule_generator(), - number_of_agents=5) + env = RailEnv(width=15, height=15, rail_generator=complex_rail_generator( + nr_start_goal=10, + nr_extra=1, + min_dist=8, + max_dist=99999), schedule_generator=complex_schedule_generator(), number_of_agents=5) env._max_episode_steps = int(15 * (env.width + env.height)) env_renderer = RenderTool(env) diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py new file mode 100644 index 0000000000000000000000000000000000000000..0de2f4b598023884a49a47473670bb8613c71eeb --- /dev/null +++ b/flatland/envs/malfunction_generators.py @@ -0,0 +1,79 @@ +"""Malfunction generators for rail systems""" + +from typing import Tuple, List, Callable + +import msgpack + +MalfunctionGenerator = Callable[[], Tuple[float, int, int]] + + +def malfunction_from_file(filename) -> MalfunctionGenerator: + """ + Utility to load pickle file + + Parameters + ---------- + input_file : Pickle file generated by env.save() or editor + + Returns + ------- + Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken + """ + + def generator(): + with open(filename, "rb") as file_in: + load_data = file_in.read() + data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') + + if "malfunction" in data: + # Mean malfunction in number of time steps + mean_malfunction_rate = data["malfunction"]["malfunction_rate"] + # Uniform distribution parameters for malfunction duration + min_number_of_steps_broken = data["malfunction"]["min_duration"] + max_number_of_steps_broken = data["malfunction"]["max_duration"] + agents_speed = None + return mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken + + return generator + +def malfunction_from_params(parameters) -> MalfunctionGenerator: + """ + Utility to load malfunction from parameters + + Parameters + ---------- + parameters containing + malfunction_rate : float how many time steps it takes for a sinlge agent befor it breaks + min_duration : int minimal duration of a failure + max_number_of_steps_broken : int maximal duration of a failure + + Returns + ------- + Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken + """ + + def generator(): + mean_malfunction_rate = parameters['malfunction_rate'] + min_number_of_steps_broken = parameters['min_duration'] + max_number_of_steps_broken = parameters['max_duration'] + return mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken + + return generator + +def no_malfunction_generator() -> MalfunctionGenerator: + """ + Utility to load malfunction from parameters + + Parameters + ---------- + input_file : Pickle file generated by env.save() or editor + + Returns + ------- + Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken + """ + + def generator(): + return 0, 0, 0 + + return generator diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index f284f3ac38b480d8feb6c0b4944cf8831d2a70d2..6da778ed40e917833c9460e9e08a8e4a516e5611 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -19,6 +19,7 @@ from flatland.core.grid.grid_utils import IntVector2D from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus from flatland.envs.distance_map import DistanceMap +from flatland.envs.malfunction_generators import MalfunctionGenerator, no_malfunction_generator from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_generators import random_rail_generator, RailGenerator from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator @@ -111,17 +112,15 @@ class RailEnv(Environment): stop_penalty = 0 # penalty for stopping a moving agent start_penalty = 0 # penalty for starting a stopped agent - def __init__(self, - width, + def __init__(self, width, height, rail_generator: RailGenerator = random_rail_generator(), schedule_generator: ScheduleGenerator = random_schedule_generator(), number_of_agents=1, obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(), - stochastic_data=None, + malfunction_generator: MalfunctionGenerator = no_malfunction_generator(), remove_agents_at_target=True, - random_seed=1 - ): + random_seed=1): """ Environment init. @@ -161,6 +160,7 @@ class RailEnv(Environment): self.rail_generator: RailGenerator = rail_generator self.schedule_generator: ScheduleGenerator = schedule_generator + self.malfunction_generator: MalfunctionGenerator = malfunction_generator self.rail: Optional[GridTransitionMap] = None self.width = width self.height = height @@ -196,19 +196,8 @@ class RailEnv(Environment): self._seed(seed=random_seed) # Stochastic train malfunctioning parameters - if stochastic_data is not None: - mean_malfunction_rate = stochastic_data['malfunction_rate'] - malfunction_min_duration = stochastic_data['min_duration'] - malfunction_max_duration = stochastic_data['max_duration'] - else: - mean_malfunction_rate = 0. - malfunction_min_duration = 0. - malfunction_max_duration = 0. - - # Mean malfunction in number of time steps + mean_malfunction_rate, malfunction_min_duration, malfunction_max_duration = self.malfunction_generator() self.mean_malfunction_rate = mean_malfunction_rate - - # Uniform distribution parameters for malfunction duration self.min_number_of_steps_broken = malfunction_min_duration self.max_number_of_steps_broken = malfunction_max_duration @@ -359,6 +348,12 @@ class RailEnv(Environment): else: self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height) + # Stochastic train malfunctioning parameters + mean_malfunction_rate, malfunction_min_duration, malfunction_max_duration = self.malfunction_generator() + self.mean_malfunction_rate = mean_malfunction_rate + self.min_number_of_steps_broken = malfunction_min_duration + self.max_number_of_steps_broken = malfunction_max_duration + self.agent_positions = np.full((self.height, self.width), False) self.restart_agents() @@ -837,22 +832,41 @@ class RailEnv(Environment): grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] agent_data = [agent.to_list() for agent in self.agents] + malfunction_data = {"malfunction_rate": self.mean_malfunction_rate, + "min_duration": self.min_number_of_steps_broken, + "max_duration": self.max_number_of_steps_broken} + msgpack.packb(grid_data, use_bin_type=True) msgpack.packb(agent_data, use_bin_type=True) msgpack.packb(agent_static_data, use_bin_type=True) msg_data = { "grid": grid_data, "agents_static": agent_static_data, - "agents": agent_data} + "agents": agent_data, + "malfunction": malfunction_data} return msgpack.packb(msg_data, use_bin_type=True) - def get_agent_state_msg(self): + def get_full_state_dist_msg(self): """ - Returns agents information in msgpack object + Returns environment information with distance map information as msgpack object """ + grid_data = self.rail.grid.tolist() + agent_static_data = [agent.to_list() for agent in self.agents_static] agent_data = [agent.to_list() for agent in self.agents] + msgpack.packb(grid_data, use_bin_type=True) + msgpack.packb(agent_data, use_bin_type=True) + msgpack.packb(agent_static_data, use_bin_type=True) + distance_map_data = self.distance_map.get() + malfunction_data = {"malfunction_rate": self.mean_malfunction_rate, + "min_duration": self.min_number_of_steps_broken, + "max_duration": self.max_number_of_steps_broken} + msgpack.packb(distance_map_data, use_bin_type=True) msg_data = { - "agents": agent_data} + "grid": grid_data, + "agents_static": agent_static_data, + "agents": agent_data, + "distance_map": distance_map_data, + "malfunction": malfunction_data} return msgpack.packb(msg_data, use_bin_type=True) def set_full_state_msg(self, msg_data): @@ -873,6 +887,12 @@ class RailEnv(Environment): self.rail.height = self.height self.rail.width = self.width self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) + if "malfunction" in data: + # Mean malfunction in number of time steps + self.mean_malfunction_rate = data["malfunction"]["malfunction_rate"] + # Uniform distribution parameters for malfunction duration + self.min_number_of_steps_broken = data["malfunction"]["min_duration"] + self.max_number_of_steps_broken = data["malfunction"]["max_duration"] def set_full_state_dist_msg(self, msg_data): """ @@ -894,26 +914,12 @@ class RailEnv(Environment): self.rail.height = self.height self.rail.width = self.width self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) - - def get_full_state_dist_msg(self): - """ - Returns environment information with distance map information as msgpack object - """ - grid_data = self.rail.grid.tolist() - agent_static_data = [agent.to_list() for agent in self.agents_static] - agent_data = [agent.to_list() for agent in self.agents] - msgpack.packb(grid_data, use_bin_type=True) - msgpack.packb(agent_data, use_bin_type=True) - msgpack.packb(agent_static_data, use_bin_type=True) - distance_map_data = self.distance_map.get() - msgpack.packb(distance_map_data, use_bin_type=True) - msg_data = { - "grid": grid_data, - "agents_static": agent_static_data, - "agents": agent_data, - "distance_map": distance_map_data} - - return msgpack.packb(msg_data, use_bin_type=True) + if "malfunction" in data: + # Mean malfunction in number of time steps + self.mean_malfunction_rate = data["malfunction"]["malfunction_rate"] + # Uniform distribution parameters for malfunction duration + self.min_number_of_steps_broken = data["malfunction"]["min_duration"] + self.max_number_of_steps_broken = data["malfunction"]["max_duration"] def save(self, filename, save_distance_maps=False): """ diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py index dc1cff12c0c8b1860859208a13d6403734a2d2ad..525fd6564b2a76e63641d38b7c93823ec2c83153 100644 --- a/flatland/envs/rail_env_utils.py +++ b/flatland/envs/rail_env_utils.py @@ -10,10 +10,7 @@ def load_flatland_environment_from_file(file_name, load_from_package=None, obs_b obs_builder_object = TreeObsForRailEnv( max_depth=2, predictor=ShortestPathPredictorForRailEnv(max_depth=10)) - environment = RailEnv(width=1, - height=1, - rail_generator=rail_from_file(file_name, load_from_package), - number_of_agents=1, - schedule_generator=schedule_from_file(file_name, load_from_package), + environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package), + schedule_generator=schedule_from_file(file_name, load_from_package), number_of_agents=1, obs_builder_object=obs_builder_object) return environment diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py index d41224f36f5ffc974e976bb362e8ba8050df4e7d..7b2e189971a364afc55cd62cd2e83bb69c9fbd89 100644 --- a/flatland/evaluators/client.py +++ b/flatland/evaluators/client.py @@ -217,13 +217,9 @@ class FlatlandRemoteClient(object): if self.verbose: print("Current env path : ", test_env_file_path) self.current_env_path = test_env_file_path - self.env = RailEnv( - width=1, - height=1, - rail_generator=rail_from_file(test_env_file_path), - schedule_generator=schedule_from_file(test_env_file_path), - obs_builder_object=obs_builder_object - ) + self.env = RailEnv(width=1, height=1, rail_generator=rail_from_file(test_env_file_path), + schedule_generator=schedule_from_file(test_env_file_path), + obs_builder_object=obs_builder_object) time_start = time.time() local_observation, info = self.env.reset( @@ -246,8 +242,8 @@ class FlatlandRemoteClient(object): _request['type'] = messages.FLATLAND_RL.ENV_STEP _request['payload'] = {} _request['payload']['action'] = action - - # Relay the action in a non-blocking way to the server + + # Relay the action in a non-blocking way to the server # so that it can start doing an env.step on it in ~ parallel self._remote_request(_request, blocking=False) diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py index 05601455549a2e1bf44f7bd9041229ba8d2cb80e..ce4cb8cffef8c93fddb79e92d06f56fd92e144b8 100644 --- a/flatland/evaluators/service.py +++ b/flatland/evaluators/service.py @@ -273,7 +273,7 @@ class FlatlandRemoteEvaluationService: ) if self.verbose: print("Received Request : ", command) - + message_queue_latency = time.time() - command["timestamp"] self.update_running_mean_stats("message_queue_latency", message_queue_latency) return command @@ -335,13 +335,9 @@ class FlatlandRemoteEvaluationService: test_env_file_path ) del self.env - self.env = RailEnv( - width=1, - height=1, - rail_generator=rail_from_file(test_env_file_path), - schedule_generator=schedule_from_file(test_env_file_path), - obs_builder_object=DummyObservationBuilder() - ) + self.env = RailEnv(width=1, height=1, rail_generator=rail_from_file(test_env_file_path), + schedule_generator=schedule_from_file(test_env_file_path), + obs_builder_object=DummyObservationBuilder()) if self.begin_simulation: # If begin simulation has already been initialized diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index f8c9afd0358d42c2829dc9b7c1fd7f3ad5198a3e..c309f9eb3b56c82b872b1842f30eace25a70026a 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -24,10 +24,7 @@ class EditorMVC(object): """ Create an Editor MVC assembly around a railenv, or create one if None. """ if env is None: - env = RailEnv(width=10, - height=10, - rail_generator=empty_rail_generator(), - number_of_agents=0, + env = RailEnv(width=10, height=10, rail_generator=empty_rail_generator(), number_of_agents=0, obs_builder_object=TreeObsForRailEnv(max_depth=2)) env.reset() @@ -669,11 +666,8 @@ class EditorModel(object): fnMethod = complex_rail_generator(nr_start_goal=nAgents, nr_extra=20, min_dist=12, seed=int(time.time())) if env is None: - self.env = RailEnv(width=self.regen_size_width, - height=self.regen_size_height, - rail_generator=fnMethod, - number_of_agents=nAgents, - obs_builder_object=TreeObsForRailEnv(max_depth=2)) + self.env = RailEnv(width=self.regen_size_width, height=self.regen_size_height, rail_generator=fnMethod, + number_of_agents=nAgents, obs_builder_object=TreeObsForRailEnv(max_depth=2)) else: self.env = env self.env.reset(regenerate_rail=True) diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py index 3bed89b8ce0947c86593e2f1680ef6082f321d84..22cea8280d377b7b7b8a118a4ba1fe3d35e972a7 100644 --- a/tests/test_distance_map.py +++ b/tests/test_distance_map.py @@ -25,14 +25,10 @@ def test_walker(): rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) rail.grid = rail_map - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, - predictor=ShortestPathPredictorForRailEnv(max_depth=10)), - ) + predictor=ShortestPathPredictorForRailEnv(max_depth=10))) # reset to initialize agents_static env.reset() diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py index e70012f8e4e77017c7dde4c3f1287e4d3bf72278..a573e55d0eef96d30189b483d6478b84653e1244 100644 --- a/tests/test_flaltland_rail_agent_status.py +++ b/tests/test_flaltland_rail_agent_status.py @@ -16,14 +16,10 @@ np.random.seed(1) def test_initial_status(): """Test that agent lifecycle works correctly ready-to-depart -> active -> done.""" rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - remove_agents_at_target=False - ) + remove_agents_at_target=False) env.reset() set_penalties_for_replay(env) test_config = ReplayConfig( @@ -127,14 +123,10 @@ def test_initial_status(): def test_status_done_remove(): """Test that agent lifecycle works correctly ready-to-depart -> active -> done.""" rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - remove_agents_at_target=True - ) + remove_agents_at_target=True) env.reset() set_penalties_for_replay(env) diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py index 0913e45959d08230a815c33d98fb6de8eb99d956..8bc7235edbbed51d65818b1c4de5197b7455ddbe 100644 --- a/tests/test_flatland_core_transition_map.py +++ b/tests/test_flatland_core_transition_map.py @@ -69,13 +69,9 @@ def check_path(env, rail, position, direction, target, expected, rendering=False def test_path_exists(rendering=False): rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) # reset to initialize agents_static env.reset() @@ -135,13 +131,9 @@ def test_path_exists(rendering=False): def test_path_not_exists(rendering=False): rail, rail_map = make_simple_rail_unconnected() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) # reset to initialize agents_static env.reset() diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index f425636467ec7cefa0169db006122999b862308a..5543f3912aa4aec750ad63024c7371b02de30309 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -20,11 +20,8 @@ from flatland.utils.simple_rail import make_simple_rail def test_global_obs(): rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) global_obs, info = env.reset() @@ -95,13 +92,9 @@ def _step_along_shortest_path(env, obs_builder, rail): def test_reward_function_conflict(rendering=False): rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=2, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=2, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) obs_builder: TreeObsForRailEnv = env.obs_builder # initialize agents_static env.reset() @@ -176,14 +169,10 @@ def test_reward_function_conflict(rendering=False): def test_reward_function_waiting(rendering=False): rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=2, + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - remove_agents_at_target=False - ) + remove_agents_at_target=False) obs_builder: TreeObsForRailEnv = env.obs_builder # initialize agents_static env.reset() diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 280d1d1143d06b9b00832b9c3eb6cbf4add0ffb2..45e0bdda32fdbf4135c9400d655e49f146bce08c 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -21,13 +21,9 @@ from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make def test_dummy_predictor(rendering=False): rail, rail_map = make_simple_rail2() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), - ) + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10))) # reset to initialize agents_static env.reset() @@ -113,13 +109,9 @@ def test_dummy_predictor(rendering=False): def test_shortest_path_predictor(rendering=False): rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) # reset to initialize agents_static env.reset() @@ -251,13 +243,9 @@ def test_shortest_path_predictor(rendering=False): def test_shortest_path_predictor_conflicts(rendering=False): rail, rail_map = make_invalid_simple_rail() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=2, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=2, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) # initialize agents_static env.reset() diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index dc4c78f9a6796d8eef3cfbeb4c54409f14406415..e6550f17aad79bc6685e716249d790aa0acb8bf7 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -30,8 +30,7 @@ def test_load_env(): def test_save_load(): env = RailEnv(width=10, height=10, rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1), - schedule_generator=complex_schedule_generator(), - number_of_agents=2) + schedule_generator=complex_schedule_generator(), number_of_agents=2) env.reset() agent_1_pos = env.agents_static[0].position agent_1_dir = env.agents_static[0].direction @@ -78,11 +77,8 @@ def test_rail_environment_single_agent(): rail = GridTransitionMap(width=3, height=3, transitions=transitions) rail.grid = rail_map - rail_env = RailEnv(width=3, - height=3, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, + rail_env = RailEnv(width=3, height=3, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) for _ in range(200): @@ -155,11 +151,9 @@ def test_dead_end(): transitions=transitions) rail.grid = rail_map - rail_env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], + rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) # We try the configuration in the 4 directions: @@ -180,11 +174,9 @@ def test_dead_end(): transitions=transitions) rail.grid = rail_map - rail_env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], + rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) rail_env.reset() @@ -198,13 +190,9 @@ def test_dead_end(): def test_get_entry_directions(): rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() def _assert(position, expected): @@ -236,13 +224,9 @@ def test_rail_env_reset(): rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=3, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=3, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() env.save(file_name) dist_map_shape = np.shape(env.distance_map.get()) @@ -250,13 +234,9 @@ def test_rail_env_reset(): rails_initial = env.rail.grid agents_initial = env.agents - env2 = RailEnv(width=1, - height=1, - rail_generator=rail_from_file(file_name), - schedule_generator=schedule_from_file(file_name), - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) + env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), + schedule_generator=schedule_from_file(file_name), number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env2.reset(False, False, False) rails_loaded = env2.rail.grid agents_loaded = env2.agents @@ -264,13 +244,9 @@ def test_rail_env_reset(): assert np.all(np.array_equal(rails_initial, rails_loaded)) assert agents_initial == agents_loaded - env3 = RailEnv(width=1, - height=1, - rail_generator=rail_from_file(file_name), - schedule_generator=schedule_from_file(file_name), - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) + env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), + schedule_generator=schedule_from_file(file_name), number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env3.reset(False, True, False) rails_loaded = env3.rail.grid agents_loaded = env3.agents @@ -278,13 +254,9 @@ def test_rail_env_reset(): assert np.all(np.array_equal(rails_initial, rails_loaded)) assert agents_initial == agents_loaded - env4 = RailEnv(width=1, - height=1, - rail_generator=rail_from_file(file_name), - schedule_generator=schedule_from_file(file_name), - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) + env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), + schedule_generator=schedule_from_file(file_name), number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env4.reset(True, False, False) rails_loaded = env4.rail.grid agents_loaded = env4.agents diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py index dd64d370077ab12950f0189065c15652e6ad1c6d..344739b798dd2d36136ff0c35698ce0025fc781d 100644 --- a/tests/test_flatland_envs_rail_env_shortest_paths.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -16,13 +16,9 @@ from flatland.utils.simple_rail import make_disconnected_simple_rail def test_get_shortest_paths_unreachable(): rail, rail_map = make_disconnected_simple_rail() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), - ) + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10))) env.reset() # set the initial position diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index a1d0fb17ffbf48f77dbad5d7a01acc20e56e30f6..1502ab34c02647d7dda3aed9c3037597ac3bfcc0 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -1,9 +1,11 @@ import random - -import numpy as np import unittest import warnings + +import numpy as np + from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d +from flatland.envs.malfunction_generators import malfunction_from_params from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator @@ -14,17 +16,13 @@ from flatland.utils.rendertools import RenderTool def test_sparse_rail_generator(): np.random.seed(0) random.seed(0) - env = RailEnv(width=50, - height=50, - rail_generator=sparse_rail_generator(max_num_cities=10, - max_rails_between_cities=3, - seed=5, - grid_mode=False - ), - schedule_generator=sparse_schedule_generator(), - number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv() - ) + env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(max_num_cities=10, + max_rails_between_cities=3, + seed=5, + grid_mode=False + ), + schedule_generator=sparse_schedule_generator(), number_of_agents=10, + obs_builder_object=GlobalObsForRailEnv()) env.reset(False, False, True) # for r in range(env.height): # for c in range (env.width): @@ -554,17 +552,13 @@ def test_sparse_rail_generator_deterministic(): 1. / 3.: 0., # Slow commuter train 1. / 4.: 0.} # Slow freight train - env = RailEnv(width=25, - height=30, - rail_generator=sparse_rail_generator(max_num_cities=5, - max_rails_between_cities=3, - seed=215545, # Random seed - grid_mode=True - ), - schedule_generator=sparse_schedule_generator(speed_ration_map), - number_of_agents=1, - stochastic_data=stochastic_data, # Malfunction data generator - ) + env = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5, + max_rails_between_cities=3, + seed=215545, # Random seed + grid_mode=True + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=1, + malfunction_generator=malfunction_from_params(stochastic_data)) env.reset() # for r in range(env.height): # for c in range(env.width): @@ -1323,42 +1317,30 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(29, 24) == 0, "[29][24]" - def test_rail_env_action_required_info(): - np.random.seed(0) random.seed(0) speed_ration_map = {1.: 0.25, # Fast passenger train 1. / 2.: 0.25, # Fast freight train 1. / 3.: 0.25, # Slow commuter train 1. / 4.: 0.25} # Slow freight train - env_always_action = RailEnv(width=50, - height=50, - rail_generator=sparse_rail_generator( - max_num_cities=10, - max_rails_between_cities=3, - seed=5, # Random seed - grid_mode=False # Ordered distribution of nodes - ), - schedule_generator=sparse_schedule_generator(speed_ration_map), - number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv(), - remove_agents_at_target=False) + env_always_action = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator( + max_num_cities=10, + max_rails_between_cities=3, + seed=5, # Random seed + grid_mode=False # Ordered distribution of nodes + ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=10, + obs_builder_object=GlobalObsForRailEnv(), remove_agents_at_target=False) np.random.seed(0) random.seed(0) - env_only_if_action_required = RailEnv(width=50, - height=50, - rail_generator=sparse_rail_generator( - max_num_cities=10, - max_rails_between_cities=3, - seed=5, # Random seed - grid_mode=False - # Ordered distribution of nodes - ), - schedule_generator=sparse_schedule_generator(speed_ration_map), - number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv(), - remove_agents_at_target=False) + env_only_if_action_required = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator( + max_num_cities=10, + max_rails_between_cities=3, + seed=5, # Random seed + grid_mode=False + # Ordered distribution of nodes + ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=10, + obs_builder_object=GlobalObsForRailEnv(), remove_agents_at_target=False) env_renderer = RenderTool(env_always_action, gl="PILSVG", ) env_always_action.reset(False, False, True) @@ -1418,17 +1400,14 @@ def test_rail_env_malfunction_speed_info(): 'min_duration': 3, # Minimal duration of malfunction 'max_duration': 10 # Max duration of malfunction } - env = RailEnv(width=50, - height=50, - rail_generator=sparse_rail_generator(max_num_cities=10, - max_rails_between_cities=3, - seed=5, - grid_mode=False - ), - schedule_generator=sparse_schedule_generator(), - number_of_agents=10, + env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(max_num_cities=10, + max_rails_between_cities=3, + seed=5, + grid_mode=False + ), + schedule_generator=sparse_schedule_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv(), - stochastic_data=stochastic_data) + malfunction_generator=malfunction_from_params(stochastic_data)) env.reset(False, False, True) env_renderer = RenderTool(env, gl="PILSVG", ) @@ -1458,17 +1437,12 @@ def test_rail_env_malfunction_speed_info(): def test_sparse_generator_with_too_man_cities_does_not_break_down(): np.random.seed(0) - RailEnv(width=50, - height=50, - rail_generator=sparse_rail_generator( - max_num_cities=100, - max_rails_between_cities=3, - seed=5, - grid_mode=False - ), - schedule_generator=sparse_schedule_generator(), - number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv()) + RailEnv(width=50, height=50, rail_generator=sparse_rail_generator( + max_num_cities=100, + max_rails_between_cities=3, + seed=5, + grid_mode=False + ), schedule_generator=sparse_schedule_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) def test_sparse_generator_with_illegal_params_aborts(): @@ -1477,29 +1451,21 @@ def test_sparse_generator_with_illegal_params_aborts(): """ np.random.seed(0) with unittest.TestCase.assertRaises(test_sparse_generator_with_illegal_params_aborts, SystemExit): - RailEnv(width=6, - height=6, - rail_generator=sparse_rail_generator( - max_num_cities=100, - max_rails_between_cities=3, - seed=5, - grid_mode=False - ), - schedule_generator=sparse_schedule_generator(), - number_of_agents=10, + RailEnv(width=6, height=6, rail_generator=sparse_rail_generator( + max_num_cities=100, + max_rails_between_cities=3, + seed=5, + grid_mode=False + ), schedule_generator=sparse_schedule_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()).reset() with unittest.TestCase.assertRaises(test_sparse_generator_with_illegal_params_aborts, SystemExit): - RailEnv(width=60, - height=60, - rail_generator=sparse_rail_generator( - max_num_cities=1, - max_rails_between_cities=3, - seed=5, - grid_mode=False - ), - schedule_generator=sparse_schedule_generator(), - number_of_agents=10, + RailEnv(width=60, height=60, rail_generator=sparse_rail_generator( + max_num_cities=1, + max_rails_between_cities=3, + seed=5, + grid_mode=False + ), schedule_generator=sparse_schedule_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()).reset() @@ -1514,17 +1480,12 @@ def test_sparse_generator_changes_to_grid_mode(): for test_run in range(10): with warnings.catch_warnings(record=True) as w: - RailEnv(width=10, - height=20, - rail_generator=sparse_rail_generator( - max_num_cities=100, - max_rails_between_cities=2, - max_rails_in_city=2, - seed=5, - grid_mode=False - ), - schedule_generator=sparse_schedule_generator(), - number_of_agents=10, + RailEnv(width=10, height=20, rail_generator=sparse_rail_generator( + max_num_cities=100, + max_rails_between_cities=2, + max_rails_in_city=2, + seed=5, + grid_mode=False + ), schedule_generator=sparse_schedule_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()).reset() assert "[WARNING]" in str(w[-1].message) - diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index e18685f6c68f757f50075ebf206423b562514dd6..14f9b6c0295306448d27a83444e3dd6496485cf1 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -8,6 +8,7 @@ from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.agent_utils import RailAgentStatus +from flatland.envs.malfunction_generators import malfunction_from_params from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.schedule_generators import random_schedule_generator @@ -72,14 +73,9 @@ def test_malfunction_process(): rail, rail_map = make_simple_rail2() - env = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, - stochastic_data=stochastic_data, # Malfunction data generator - obs_builder_object=SingleAgentNavigationObs() - ) + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, + obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data)) # reset to initialize agents_static obs, info = env.reset(False, False, True, random_seed=10) @@ -126,14 +122,9 @@ def test_malfunction_process_statistically(): rail, rail_map = make_simple_rail2() - env = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=10, - stochastic_data=stochastic_data, # Malfunction data generator - obs_builder_object=SingleAgentNavigationObs() - ) + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=10, + obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data)) # reset to initialize agents_static env.reset(True, True, False, random_seed=10) @@ -173,14 +164,9 @@ def test_malfunction_before_entry(): rail, rail_map = make_simple_rail2() - env = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=1), # seed 12 - number_of_agents=10, - random_seed=1, - stochastic_data=stochastic_data, # Malfunction data generator - ) + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=1), number_of_agents=10, + malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1) # reset to initialize agents_static env.reset(False, False, False, random_seed=10) env.agents[0].target = (0, 0) @@ -216,14 +202,9 @@ def test_malfunction_values_and_behavior(): stochastic_data = {'malfunction_rate': 0.001, 'min_duration': 10, 'max_duration': 10} - env = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=2), # seed 12 - stochastic_data=stochastic_data, - number_of_agents=1, - random_seed=1, - ) + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, + malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1) # reset to initialize agents_static env.reset(False, False, activate_agents=True, random_seed=10) @@ -246,14 +227,9 @@ def test_initial_malfunction(): rail, rail_map = make_simple_rail2() - env = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=10), - number_of_agents=1, - stochastic_data=stochastic_data, # Malfunction data generator - obs_builder_object=SingleAgentNavigationObs() - ) + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=10), number_of_agents=1, + obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data)) # reset to initialize agents_static env.reset(False, False, True, random_seed=10) print(env.agents[0].malfunction_data) @@ -318,14 +294,9 @@ def test_initial_malfunction_stop_moving(): rail, rail_map = make_simple_rail2() - env = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, - stochastic_data=stochastic_data, # Malfunction data generator - obs_builder_object=SingleAgentNavigationObs() - ) + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, + obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data)) env.reset() print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status) @@ -409,13 +380,9 @@ def test_initial_malfunction_do_nothing(): rail, rail_map = make_simple_rail2() - env = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, - stochastic_data=stochastic_data, # Malfunction data generator - ) + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, + malfunction_generator=malfunction_from_params(stochastic_data)) # reset to initialize agents_static env.reset() set_penalties_for_replay(env) @@ -492,14 +459,9 @@ def tests_random_interference_from_outside(): 'max_duration': 10} rail, rail_map = make_simple_rail2() - env = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=2), # seed 12 - number_of_agents=1, - random_seed=1, - stochastic_data=stochastic_data, # Malfunction data generator - ) + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, + malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1) env.reset() # reset to initialize agents_static env.agents[0].speed_data['speed'] = 0.33 @@ -523,14 +485,9 @@ def tests_random_interference_from_outside(): rail, rail_map = make_simple_rail2() random.seed(47) np.random.seed(1234) - env = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=2), # seed 12 - number_of_agents=1, - random_seed=1, - stochastic_data=stochastic_data, # Malfunction data generator - ) + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, + malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1) env.reset() # reset to initialize agents_static env.agents[0].speed_data['speed'] = 0.33 @@ -565,14 +522,9 @@ def test_last_malfunction_step(): rail, rail_map = make_simple_rail2() - env = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=2), # seed 12 - number_of_agents=1, - random_seed=1, - stochastic_data=stochastic_data, # Malfunction data generator - ) + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, + malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1) env.reset() # reset to initialize agents_static env.agents[0].speed_data['speed'] = 1. / 3. diff --git a/tests/test_flatland_utils_rendertools.py b/tests/test_flatland_utils_rendertools.py index 6e1fb2441d4428b24da3c37d764a7676f3929a2a..6ed92fefb0c81512fc6006cbf44b6d55a274caf3 100644 --- a/tests/test_flatland_utils_rendertools.py +++ b/tests/test_flatland_utils_rendertools.py @@ -37,11 +37,8 @@ def checkFrozenImage(oRT, sFileImage, resave=False): def test_render_env(save_new_images=False): np.random.seed(100) - oEnv = RailEnv(width=10, height=10, - rail_generator=empty_rail_generator(), - number_of_agents=0, - obs_builder_object=TreeObsForRailEnv(max_depth=2) - ) + oEnv = RailEnv(width=10, height=10, rail_generator=empty_rail_generator(), number_of_agents=0, + obs_builder_object=TreeObsForRailEnv(max_depth=2)) oEnv.reset() oEnv.rail.load_transition_map('env_data.tests', "test1.npy") oRT = rt.RenderTool(oEnv, gl="PILSVG") diff --git a/tests/test_generators.py b/tests/test_generators.py index 1e69223daebd24c52137e12eed9dc43d188a9bbd..83ef0d76360d1c662c9e181c6f88f41adb62ebbf 100644 --- a/tests/test_generators.py +++ b/tests/test_generators.py @@ -20,11 +20,7 @@ def test_empty_rail_generator(): y_dim = 10 # Check that a random level at with correct parameters is generated - env = RailEnv(width=x_dim, - height=y_dim, - number_of_agents=n_agents, - rail_generator=empty_rail_generator() - ) + env = RailEnv(width=x_dim, height=y_dim, rail_generator=empty_rail_generator(), number_of_agents=n_agents) env.reset() # Check the dimensions assert env.rail.grid.shape == (y_dim, x_dim) @@ -41,11 +37,7 @@ def test_random_rail_generator(): y_dim = 10 # Check that a random level at with correct parameters is generated - env = RailEnv(width=x_dim, - height=y_dim, - number_of_agents=n_agents, - rail_generator=random_rail_generator() - ) + env = RailEnv(width=x_dim, height=y_dim, rail_generator=random_rail_generator(), number_of_agents=n_agents) env.reset() assert env.rail.grid.shape == (y_dim, x_dim) assert env.get_num_agents() == n_agents @@ -59,12 +51,9 @@ def test_complex_rail_generator(): min_dist = 4 # Check that agent number is changed to fit generated level - env = RailEnv(width=x_dim, - height=y_dim, - number_of_agents=n_agents, + env = RailEnv(width=x_dim, height=y_dim, rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), - schedule_generator=complex_schedule_generator() - ) + schedule_generator=complex_schedule_generator(), number_of_agents=n_agents) env.reset() assert env.get_num_agents() == 2 assert env.rail.grid.shape == (y_dim, x_dim) @@ -72,12 +61,9 @@ def test_complex_rail_generator(): min_dist = 2 * x_dim # Check that no agents are generated when level cannot be generated - env = RailEnv(width=x_dim, - height=y_dim, - number_of_agents=n_agents, + env = RailEnv(width=x_dim, height=y_dim, rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), - schedule_generator=complex_schedule_generator() - ) + schedule_generator=complex_schedule_generator(), number_of_agents=n_agents) env.reset() assert env.get_num_agents() == 0 assert env.rail.grid.shape == (y_dim, x_dim) @@ -87,12 +73,9 @@ def test_complex_rail_generator(): n_start = 5 n_agents = 5 - env = RailEnv(width=x_dim, - height=y_dim, - number_of_agents=n_agents, + env = RailEnv(width=x_dim, height=y_dim, rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), - schedule_generator=complex_schedule_generator() - ) + schedule_generator=complex_schedule_generator(), number_of_agents=n_agents) env.reset() assert env.get_num_agents() == n_agents assert env.rail.grid.shape == (y_dim, x_dim) @@ -101,12 +84,8 @@ def test_complex_rail_generator(): def test_rail_from_grid_transition_map(): rail, rail_map = make_simple_rail() n_agents = 3 - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=n_agents - ) + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=n_agents) env.reset(False, False, True) nr_rail_elements = np.count_nonzero(env.rail.grid) @@ -127,13 +106,9 @@ def tests_rail_from_file(): rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=3, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=3, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() env.save(file_name) dist_map_shape = np.shape(env.distance_map.get()) @@ -141,13 +116,9 @@ def tests_rail_from_file(): rails_initial = env.rail.grid agents_initial = env.agents - env = RailEnv(width=1, - height=1, - rail_generator=rail_from_file(file_name), - schedule_generator=schedule_from_file(file_name), - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) + env = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), + schedule_generator=schedule_from_file(file_name), number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() rails_loaded = env.rail.grid agents_loaded = env.agents @@ -163,13 +134,9 @@ def tests_rail_from_file(): file_name_2 = "test_without_distance_map.pkl" - env2 = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=3, - obs_builder_object=GlobalObsForRailEnv(), - ) + env2 = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), + number_of_agents=3, obs_builder_object=GlobalObsForRailEnv()) env2.reset() env2.save(file_name_2) @@ -177,13 +144,9 @@ def tests_rail_from_file(): rails_initial_2 = env2.rail.grid agents_initial_2 = env2.agents - env2 = RailEnv(width=1, - height=1, - rail_generator=rail_from_file(file_name_2), - schedule_generator=schedule_from_file(file_name_2), - number_of_agents=1, - obs_builder_object=GlobalObsForRailEnv(), - ) + env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2), + schedule_generator=schedule_from_file(file_name_2), number_of_agents=1, + obs_builder_object=GlobalObsForRailEnv()) env2.reset() rails_loaded_2 = env2.rail.grid agents_loaded_2 = env2.agents @@ -194,13 +157,9 @@ def tests_rail_from_file(): # Test to save with distance map and load without - env3 = RailEnv(width=1, - height=1, - rail_generator=rail_from_file(file_name), - schedule_generator=schedule_from_file(file_name), - number_of_agents=1, - obs_builder_object=GlobalObsForRailEnv(), - ) + env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), + schedule_generator=schedule_from_file(file_name), number_of_agents=1, + obs_builder_object=GlobalObsForRailEnv()) env3.reset() rails_loaded_3 = env3.rail.grid agents_loaded_3 = env3.agents @@ -212,13 +171,9 @@ def tests_rail_from_file(): # Test to save without distance map and load with generating distance map # initialize agents_static - env4 = RailEnv(width=1, - height=1, - rail_generator=rail_from_file(file_name_2), - schedule_generator=schedule_from_file(file_name_2), - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2), - ) + env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2), + schedule_generator=schedule_from_file(file_name_2), number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2)) env4.reset() rails_loaded_4 = env4.rail.grid agents_loaded_4 = env4.agents diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py index d3bcf7779dbc4c8dbafe6e726aeff33c757c25fb..afaf2b7f496ef31c7d5228d1f389c254cc16df2e 100644 --- a/tests/test_global_observation.py +++ b/tests/test_global_observation.py @@ -22,16 +22,13 @@ def test_get_global_observation(): 1. / 3.: 0.25, # Slow commuter train 1. / 4.: 0.25} # Slow freight train - env = RailEnv(width=50, - height=50, - rail_generator=sparse_rail_generator(max_num_cities=6, - max_rails_between_cities=4, - seed=15, - grid_mode=False - ), - schedule_generator=sparse_schedule_generator(speed_ration_map), - number_of_agents=number_of_agents, stochastic_data=stochastic_data, # Malfunction data generator - obs_builder_object=GlobalObsForRailEnv()) + env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(max_num_cities=6, + max_rails_between_cities=4, + seed=15, + grid_mode=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=number_of_agents, + obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data) env.reset() obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)}) diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index f83990cc39bf73e50719b2291006eed68d1d1360..f5fc66662149816c9345be738d8bd30a8bc8306f 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -49,11 +49,9 @@ class RandomAgent: def test_multi_speed_init(): - env = RailEnv(width=50, - height=50, + env = RailEnv(width=50, height=50, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, - seed=1), - schedule_generator=complex_schedule_generator(), + seed=1), schedule_generator=complex_schedule_generator(), number_of_agents=5) # Initialize the agent with the parameters corresponding to the environment and observation_builder agent = RandomAgent(218, 4) @@ -97,13 +95,9 @@ def test_multi_speed_init(): def test_multispeed_actions_no_malfunction_no_blocking(): """Test that actions are correctly performed on cell exit for a single agent.""" rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() set_penalties_for_replay(env) @@ -201,13 +195,9 @@ def test_multispeed_actions_no_malfunction_no_blocking(): def test_multispeed_actions_no_malfunction_blocking(): """The second agent blocks the first because it is slower.""" rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=2, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=2, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() set_penalties_for_replay(env) test_configs = [ @@ -389,13 +379,9 @@ def test_multispeed_actions_no_malfunction_blocking(): def test_multispeed_actions_malfunction_no_blocking(): """Test on a single agent whether action on cell exit work correctly despite malfunction.""" rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() set_penalties_for_replay(env) @@ -527,13 +513,9 @@ def test_multispeed_actions_malfunction_no_blocking(): def test_multispeed_actions_no_malfunction_invalid_actions(): """Test that actions are correctly performed on cell exit for a single agent.""" rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() set_penalties_for_replay(env) diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py index a1981a9a0d2ed80b98f8d1e1fc8a49e14624afae..4ce04e5e5e07ca5b5e864efffea951221322f306 100644 --- a/tests/test_random_seeding.py +++ b/tests/test_random_seeding.py @@ -14,12 +14,8 @@ def test_random_seeding(): # Move target to unreachable position in order to not interfere with test for idx in range(100): - env = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=12), - number_of_agents=10 - ) + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=12), number_of_agents=10) env.reset(True, True, False, random_seed=1) env.agents[0].target = (0, 0) @@ -52,21 +48,13 @@ def test_seeding_and_observations(): # Make two seperate envs with different observation builders # Global Observation - env = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=12), - number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv() - ) + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=12), number_of_agents=10, + obs_builder_object=GlobalObsForRailEnv()) # Tree Observation - env2 = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=12), - number_of_agents=10, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) - ) + env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=12), number_of_agents=10, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset(False, False, False, random_seed=12) env2.reset(False, False, False, random_seed=12) @@ -118,24 +106,14 @@ def test_seeding_and_malfunction(): # Make two seperate envs with different and see if the exhibit the same malfunctions # Global Observation for tests in range(1, 100): - env = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv(), - stochastic_data=stochastic_data, # Malfunction data generator - ) + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=10, + obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data) # Tree Observation - env2 = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv(), - stochastic_data=stochastic_data, # Malfunction data generator - ) + env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=10, + obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data) env.reset(True, False, True, random_seed=tests) env2.reset(True, False, True, random_seed=tests) diff --git a/tests/test_speed_classes.py b/tests/test_speed_classes.py index 1fcf3b3ef0b7cc176d345fa547e91eeeef0a05bd..c1c03c3676cb4b6847265af02432ba8b37ca4cfe 100644 --- a/tests/test_speed_classes.py +++ b/tests/test_speed_classes.py @@ -18,11 +18,9 @@ def test_speed_initialization_helper(): def test_rail_env_speed_intializer(): speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2} - env = RailEnv(width=50, - height=50, + env = RailEnv(width=50, height=50, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, - seed=1), - schedule_generator=complex_schedule_generator(), + seed=1), schedule_generator=complex_schedule_generator(), number_of_agents=10) env.reset() actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents))