From 7209aad8dac944d1d9a88b27760c30cba17f5a83 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Wed, 28 Aug 2019 13:40:56 +0200 Subject: [PATCH] merge #141 renamed agent_generator* to schedule_generator* --- examples/complex_rail_benchmark.py | 4 +- examples/custom_observation_example.py | 6 +- examples/custom_railmap_example.py | 2 +- examples/debugging_example_DELETE.py | 4 +- examples/flatland_2_0_example.py | 4 +- examples/simple_example_3.py | 4 +- examples/training_example.py | 4 +- flatland/cli.py | 4 +- flatland/envs/grid4_generators_utils.py | 78 ------------------- flatland/envs/rail_env.py | 14 ++-- flatland/envs/schedule_generators.py | 6 +- tests/test_distance_map.py | 4 +- tests/test_flatland_envs_observations.py | 8 +- tests/test_flatland_envs_predictions.py | 8 +- tests/test_flatland_envs_rail_env.py | 10 +-- ...est_flatland_envs_sparse_rail_generator.py | 4 +- tests/test_flatland_malfunction.py | 4 +- tests/test_multi_speed.py | 4 +- tests/test_speed_classes.py | 4 +- tests/tests_generators.py | 22 +++--- 20 files changed, 60 insertions(+), 138 deletions(-) diff --git a/examples/complex_rail_benchmark.py b/examples/complex_rail_benchmark.py index a8d5a78d..49e550b1 100644 --- a/examples/complex_rail_benchmark.py +++ b/examples/complex_rail_benchmark.py @@ -5,7 +5,7 @@ import numpy as np from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator def run_benchmark(): @@ -16,7 +16,7 @@ def run_benchmark(): # Example generate a random rail env = RailEnv(width=15, height=15, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=5) n_trials = 20 diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py index 4f3a18e5..8b1de6aa 100644 --- a/examples/custom_observation_example.py +++ b/examples/custom_observation_example.py @@ -9,7 +9,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import random_rail_generator, complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool random.seed(100) @@ -93,7 +93,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): env = RailEnv(width=7, height=7, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=1, obs_builder_object=SingleAgentNavigationObs()) @@ -204,7 +204,7 @@ CustomObsBuilder = ObservePredictions(CustomPredictor) env = RailEnv(width=10, height=10, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=3, obs_builder_object=CustomObsBuilder) diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py index f3350d69..04da6690 100644 --- a/examples/custom_railmap_example.py +++ b/examples/custom_railmap_example.py @@ -29,7 +29,7 @@ def custom_rail_generator() -> RailGenerator: return generator -def custom_agent_generator() -> ScheduleGenerator: +def custom_schedule_generator() -> ScheduleGenerator: def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct: agents_positions = [] agents_direction = [] diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index c4eae5f4..50ea74b8 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -6,7 +6,7 @@ import numpy as np from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool random.seed(1) @@ -61,7 +61,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): env = RailEnv(width=14, height=14, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=2, obs_builder_object=SingleAgentNavigationObs()) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 9f4d62cf..74032a9b 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -4,7 +4,7 @@ from flatland.envs.generators import sparse_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv -from flatland.envs.schedule_generators import sparse_rail_generator_agents_placer +from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.utils.rendertools import RenderTool np.random.seed(1) @@ -32,7 +32,7 @@ env = RailEnv(width=20, realistic_mode=True, enhance_intersection=True ), - agent_generator=sparse_rail_generator_agents_placer(), + schedule_generator=sparse_schedule_generator(), number_of_agents=5, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=TreeObservation) diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index 66c44ed1..6df6d4af 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -5,7 +5,7 @@ import numpy as np from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool random.seed(1) @@ -14,7 +14,7 @@ np.random.seed(1) env = RailEnv(width=7, height=7, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2)) diff --git a/examples/training_example.py b/examples/training_example.py index 3c1cddd0..df93479f 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -4,7 +4,7 @@ from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool np.random.seed(1) @@ -17,7 +17,7 @@ LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) env = RailEnv(width=20, height=20, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObservation, number_of_agents=3) diff --git a/flatland/cli.py b/flatland/cli.py index b2509287..47c450db 100644 --- a/flatland/cli.py +++ b/flatland/cli.py @@ -10,7 +10,7 @@ import redis from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.evaluators.service import FlatlandRemoteEvaluationService from flatland.utils.rendertools import RenderTool @@ -26,7 +26,7 @@ def demo(args=None): nr_extra=1, min_dist=8, max_dist=99999), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=5) env._max_episode_steps = int(15 * (env.width + env.height)) diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index d6046d6b..996bd73a 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -193,81 +193,3 @@ def connect_to_nodes(rail_trans, rail_array, start, end): current_dir = new_dir return path - - -def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): - """ - Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). - - TODO: add extensive documentation, as users may need this function to simplify their custom level generators. - """ - - def _path_exists(rail, start, direction, end): - # BFS - Check if a path exists between the 2 nodes - - visited = set() - stack = [(start, direction)] - while stack: - node = stack.pop() - if node[0][0] == end[0] and node[0][1] == end[1]: - return 1 - if node not in visited: - visited.add(node) - moves = rail.get_transitions(node[0][0], node[0][1], node[1]) - for move_index in range(4): - if moves[move_index]: - stack.append((get_new_position(node[0], move_index), - move_index)) - - # If cell is a dead-end, append previous node with reversed - # orientation! - nbits = 0 - tmp = rail.get_full_transitions(node[0][0], node[0][1]) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - stack.append((node[0], (node[1] + 2) % 4)) - - return 0 - - valid_positions = [] - for r in range(rail.height): - for c in range(rail.width): - if rail.get_full_transitions(r, c) > 0: - valid_positions.append((r, c)) - - re_generate = True - while re_generate: - agents_position = [ - valid_positions[i] for i in - np.random.choice(len(valid_positions), num_agents)] - agents_target = [ - valid_positions[i] for i in - np.random.choice(len(valid_positions), num_agents)] - - # agents_direction must be a direction for which a solution is - # guaranteed. - agents_direction = [0] * num_agents - re_generate = False - for i in range(num_agents): - valid_movements = [] - for direction in range(4): - position = agents_position[i] - moves = rail.get_transitions(position[0], position[1], direction) - for move_index in range(4): - if moves[move_index]: - valid_movements.append((direction, move_index)) - - valid_starting_directions = [] - for m in valid_movements: - new_position = get_new_position(agents_position[i], m[1]) - if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]): - valid_starting_directions.append(m[0]) - - if len(valid_starting_directions) == 0: - re_generate = True - else: - agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]] - - return agents_position, agents_direction, agents_target diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index df0ee1f7..a61ef022 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -15,7 +15,7 @@ from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_generators import random_rail_generator, RailGenerator -from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail, ScheduleGenerator +from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator m.patch() @@ -94,7 +94,7 @@ class RailEnv(Environment): width, height, rail_generator: RailGenerator = random_rail_generator(), - agent_generator: ScheduleGenerator = get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator: ScheduleGenerator = random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), max_episode_steps=None, @@ -110,10 +110,10 @@ class RailEnv(Environment): height and agents handles of a rail environment, along with the number of times the env has been reset, and returns a GridTransitionMap object and a list of starting positions, targets, and initial orientations for agent handle. - The rail_generator can pass a distance map in the hints or information for specific agent_generators. + The rail_generator can pass a distance map in the hints or information for specific schedule_generators. Implementations can be found in flatland/envs/rail_generators.py - agent_generator : function - The agent_generator function is a function that takes the grid, the number of agents and optional hints + schedule_generator : function + The schedule_generator function is a function that takes the grid, the number of agents and optional hints and returns a list of starting positions, targets, initial orientations and speed for all agent handles. Implementations can be found in flatland/envs/schedule_generators.py width : int @@ -134,7 +134,7 @@ class RailEnv(Environment): """ self.rail_generator: RailGenerator = rail_generator - self.agent_generator: ScheduleGenerator = agent_generator + self.schedule_generator: ScheduleGenerator = schedule_generator self.rail_generator = rail_generator self.rail: GridTransitionMap = None self.width = width @@ -237,7 +237,7 @@ class RailEnv(Environment): if optionals and 'agents_hints' in optionals: agents_hints = optionals['agents_hints'] self.agents_static = EnvAgentStatic.from_lists( - *self.agent_generator(self.rail, self.get_num_agents(), hints=agents_hints)) + *self.schedule_generator(self.rail, self.get_num_agents(), hints=agents_hints)) self.restart_agents() diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index ef1f9666..0ebc6c71 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -38,7 +38,7 @@ def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios))) -def complex_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: +def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): start_goal = hints['start_goal'] start_dir = hints['start_dir'] @@ -56,7 +56,7 @@ def complex_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] return generator -def sparse_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: +def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): train_stations = hints['train_stations'] agent_start_targets_nodes = hints['agent_start_targets_nodes'] @@ -111,7 +111,7 @@ def sparse_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] = return generator -def get_rnd_agents_pos_tgt_dir_on_rail(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: +def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: """ Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py index 583830f7..e5e89f76 100644 --- a/tests/test_distance_map.py +++ b/tests/test_distance_map.py @@ -6,7 +6,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail +from flatland.envs.schedule_generators import random_schedule_generator def test_walker(): @@ -28,7 +28,7 @@ def test_walker(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(max_depth=10)), diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index 63f4a9a5..c96e8db0 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -9,7 +9,7 @@ from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail +from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail @@ -22,7 +22,7 @@ def test_global_obs(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -92,7 +92,7 @@ def test_reward_function_conflict(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -171,7 +171,7 @@ def test_reward_function_waiting(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 65894b50..09f7e5e6 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -9,7 +9,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail +from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail @@ -22,7 +22,7 @@ def test_dummy_predictor(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), ) @@ -113,7 +113,7 @@ def test_shortest_path_predictor(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -233,7 +233,7 @@ def test_shortest_path_predictor_conflicts(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index 656059ac..d5dc3ac7 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -10,7 +10,7 @@ from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail, complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator """Tests for `flatland` package.""" @@ -27,7 +27,7 @@ def test_load_env(): def test_save_load(): env = RailEnv(width=10, height=10, rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=2) env.reset() agent_1_pos = env.agents_static[0].position @@ -79,7 +79,7 @@ def test_rail_environment_single_agent(): rail_env = RailEnv(width=3, height=3, rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -159,7 +159,7 @@ def test_dead_end(): rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -204,7 +204,7 @@ def test_dead_end(): rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index 5f0cc81e..db7cac61 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -1,7 +1,7 @@ from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator -from flatland.envs.schedule_generators import sparse_rail_generator_agents_placer +from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.utils.rendertools import RenderTool @@ -17,7 +17,7 @@ def test_sparse_rail_generator(): seed=5, # Random seed realistic_mode=False # Ordered distribution of nodes ), - agent_generator=sparse_rail_generator_agents_placer(), + schedule_generator=sparse_schedule_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) # reset to initialize agents_static diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 0877e63c..eaf782df 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -3,7 +3,7 @@ import numpy as np from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator class SingleAgentNavigationObs(TreeObsForRailEnv): @@ -63,7 +63,7 @@ def test_malfunction_process(): height=20, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=2, obs_builder_object=SingleAgentNavigationObs(), stochastic_data=stochastic_data) diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index dff0d2c2..8de36c81 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -2,7 +2,7 @@ import numpy as np from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator np.random.seed(1) @@ -48,7 +48,7 @@ def test_multi_speed_init(): height=50, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=5) # Initialize the agent with the parameters corresponding to the environment and observation_builder agent = RandomAgent(218, 4) diff --git a/tests/test_speed_classes.py b/tests/test_speed_classes.py index d2754c2a..ff5ee56a 100644 --- a/tests/test_speed_classes.py +++ b/tests/test_speed_classes.py @@ -3,7 +3,7 @@ import numpy as np from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import speed_initialization_helper, complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import speed_initialization_helper, complex_schedule_generator def test_speed_initialization_helper(): @@ -22,7 +22,7 @@ def test_rail_env_speed_intializer(): height=50, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=10) env.reset() actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents)) diff --git a/tests/tests_generators.py b/tests/tests_generators.py index 31cc8d1f..610022ca 100644 --- a/tests/tests_generators.py +++ b/tests/tests_generators.py @@ -8,7 +8,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \ random_rail_generator, empty_rail_generator -from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail, complex_rail_generator_agents_placer, \ +from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator, \ agents_from_file from flatland.utils.simple_rail import make_simple_rail @@ -61,7 +61,7 @@ def test_complex_rail_generator(): height=y_dim, number_of_agents=n_agents, rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), - agent_generator=complex_rail_generator_agents_placer() + schedule_generator=complex_schedule_generator() ) assert env.get_num_agents() == 2 assert env.rail.grid.shape == (y_dim, x_dim) @@ -73,7 +73,7 @@ def test_complex_rail_generator(): height=y_dim, number_of_agents=n_agents, rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), - agent_generator=complex_rail_generator_agents_placer() + schedule_generator=complex_schedule_generator() ) assert env.get_num_agents() == 0 assert env.rail.grid.shape == (y_dim, x_dim) @@ -87,7 +87,7 @@ def test_complex_rail_generator(): height=y_dim, number_of_agents=n_agents, rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), - agent_generator=complex_rail_generator_agents_placer() + schedule_generator=complex_schedule_generator() ) assert env.get_num_agents() == n_agents assert env.rail.grid.shape == (y_dim, x_dim) @@ -99,7 +99,7 @@ def test_rail_from_grid_transition_map(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=n_agents ) nr_rail_elements = np.count_nonzero(env.rail.grid) @@ -124,7 +124,7 @@ def tests_rail_from_file(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=3, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -137,7 +137,7 @@ def tests_rail_from_file(): env = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), - agent_generator=agents_from_file(file_name), + schedule_generator=agents_from_file(file_name), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -159,7 +159,7 @@ def tests_rail_from_file(): env2 = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=3, obs_builder_object=GlobalObsForRailEnv(), ) @@ -173,7 +173,7 @@ def tests_rail_from_file(): env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2), - agent_generator=agents_from_file(file_name_2), + schedule_generator=agents_from_file(file_name_2), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv(), ) @@ -190,7 +190,7 @@ def tests_rail_from_file(): env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), - agent_generator=agents_from_file(file_name), + schedule_generator=agents_from_file(file_name), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv(), ) @@ -208,7 +208,7 @@ def tests_rail_from_file(): env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2), - agent_generator=agents_from_file(file_name_2), + schedule_generator=agents_from_file(file_name_2), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), ) -- GitLab