diff --git a/examples/complex_rail_benchmark.py b/examples/complex_rail_benchmark.py index a8d5a78d03a3003c6e095ea5c0662c7777ff38e0..49e550b195ffb15e6554413069369378e80e5f82 100644 --- a/examples/complex_rail_benchmark.py +++ b/examples/complex_rail_benchmark.py @@ -5,7 +5,7 @@ import numpy as np from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator def run_benchmark(): @@ -16,7 +16,7 @@ def run_benchmark(): # Example generate a random rail env = RailEnv(width=15, height=15, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=5) n_trials = 20 diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py index 4f3a18e52586662e69557adf3c986e518332c04d..8b1de6aa4e303469d30983d30333fbfda89c1d1e 100644 --- a/examples/custom_observation_example.py +++ b/examples/custom_observation_example.py @@ -9,7 +9,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import random_rail_generator, complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool random.seed(100) @@ -93,7 +93,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): env = RailEnv(width=7, height=7, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=1, obs_builder_object=SingleAgentNavigationObs()) @@ -204,7 +204,7 @@ CustomObsBuilder = ObservePredictions(CustomPredictor) env = RailEnv(width=10, height=10, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=3, obs_builder_object=CustomObsBuilder) diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py index f3350d697fa76849e5cfc07164825e4675dc5380..04da66904fda1a58847a4acc510d7fc4e4e86887 100644 --- a/examples/custom_railmap_example.py +++ b/examples/custom_railmap_example.py @@ -29,7 +29,7 @@ def custom_rail_generator() -> RailGenerator: return generator -def custom_agent_generator() -> ScheduleGenerator: +def custom_schedule_generator() -> ScheduleGenerator: def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct: agents_positions = [] agents_direction = [] diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index c4eae5f42e323dc6fe39d265968461aa59fcab81..50ea74b84ac9851e88e48bcd32b914e69bc7dd34 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -6,7 +6,7 @@ import numpy as np from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool random.seed(1) @@ -61,7 +61,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): env = RailEnv(width=14, height=14, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=2, obs_builder_object=SingleAgentNavigationObs()) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 9f4d62cf39c569ece6498e149743e17b17905eaf..74032a9b766285715bc87037b46a9cb3332d0598 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -4,7 +4,7 @@ from flatland.envs.generators import sparse_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv -from flatland.envs.schedule_generators import sparse_rail_generator_agents_placer +from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.utils.rendertools import RenderTool np.random.seed(1) @@ -32,7 +32,7 @@ env = RailEnv(width=20, realistic_mode=True, enhance_intersection=True ), - agent_generator=sparse_rail_generator_agents_placer(), + schedule_generator=sparse_schedule_generator(), number_of_agents=5, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=TreeObservation) diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index 66c44ed15a61a6e17d9e53ffe365396b26ca3b9b..6df6d4af3076b3d9659aadbd55296b667dc7d6db 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -5,7 +5,7 @@ import numpy as np from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool random.seed(1) @@ -14,7 +14,7 @@ np.random.seed(1) env = RailEnv(width=7, height=7, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2)) diff --git a/examples/training_example.py b/examples/training_example.py index 3c1cddd03337e63b234f83cc98aca4993cd26de8..df93479f5a5ee05abfcb1a98b07ef052bffc2bd4 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -4,7 +4,7 @@ from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool np.random.seed(1) @@ -17,7 +17,7 @@ LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) env = RailEnv(width=20, height=20, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObservation, number_of_agents=3) diff --git a/flatland/cli.py b/flatland/cli.py index b2509287185a1582573a6eba347ef7d6ec27f10f..47c450dba803fac17bc13979663ef04e4c0db899 100644 --- a/flatland/cli.py +++ b/flatland/cli.py @@ -10,7 +10,7 @@ import redis from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.evaluators.service import FlatlandRemoteEvaluationService from flatland.utils.rendertools import RenderTool @@ -26,7 +26,7 @@ def demo(args=None): nr_extra=1, min_dist=8, max_dist=99999), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=5) env._max_episode_steps = int(15 * (env.width + env.height)) diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index d6046d6b8867988d40df30041241f17b685bc83a..996bd73ad9de598eb162a937c135681675119ad3 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -193,81 +193,3 @@ def connect_to_nodes(rail_trans, rail_array, start, end): current_dir = new_dir return path - - -def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): - """ - Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). - - TODO: add extensive documentation, as users may need this function to simplify their custom level generators. - """ - - def _path_exists(rail, start, direction, end): - # BFS - Check if a path exists between the 2 nodes - - visited = set() - stack = [(start, direction)] - while stack: - node = stack.pop() - if node[0][0] == end[0] and node[0][1] == end[1]: - return 1 - if node not in visited: - visited.add(node) - moves = rail.get_transitions(node[0][0], node[0][1], node[1]) - for move_index in range(4): - if moves[move_index]: - stack.append((get_new_position(node[0], move_index), - move_index)) - - # If cell is a dead-end, append previous node with reversed - # orientation! - nbits = 0 - tmp = rail.get_full_transitions(node[0][0], node[0][1]) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - stack.append((node[0], (node[1] + 2) % 4)) - - return 0 - - valid_positions = [] - for r in range(rail.height): - for c in range(rail.width): - if rail.get_full_transitions(r, c) > 0: - valid_positions.append((r, c)) - - re_generate = True - while re_generate: - agents_position = [ - valid_positions[i] for i in - np.random.choice(len(valid_positions), num_agents)] - agents_target = [ - valid_positions[i] for i in - np.random.choice(len(valid_positions), num_agents)] - - # agents_direction must be a direction for which a solution is - # guaranteed. - agents_direction = [0] * num_agents - re_generate = False - for i in range(num_agents): - valid_movements = [] - for direction in range(4): - position = agents_position[i] - moves = rail.get_transitions(position[0], position[1], direction) - for move_index in range(4): - if moves[move_index]: - valid_movements.append((direction, move_index)) - - valid_starting_directions = [] - for m in valid_movements: - new_position = get_new_position(agents_position[i], m[1]) - if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]): - valid_starting_directions.append(m[0]) - - if len(valid_starting_directions) == 0: - re_generate = True - else: - agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]] - - return agents_position, agents_direction, agents_target diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index df0ee1f7f91c80fe160120eb527b40196c15d870..a61ef02207174d04489b5311dc042b7c06db1412 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -15,7 +15,7 @@ from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_generators import random_rail_generator, RailGenerator -from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail, ScheduleGenerator +from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator m.patch() @@ -94,7 +94,7 @@ class RailEnv(Environment): width, height, rail_generator: RailGenerator = random_rail_generator(), - agent_generator: ScheduleGenerator = get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator: ScheduleGenerator = random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), max_episode_steps=None, @@ -110,10 +110,10 @@ class RailEnv(Environment): height and agents handles of a rail environment, along with the number of times the env has been reset, and returns a GridTransitionMap object and a list of starting positions, targets, and initial orientations for agent handle. - The rail_generator can pass a distance map in the hints or information for specific agent_generators. + The rail_generator can pass a distance map in the hints or information for specific schedule_generators. Implementations can be found in flatland/envs/rail_generators.py - agent_generator : function - The agent_generator function is a function that takes the grid, the number of agents and optional hints + schedule_generator : function + The schedule_generator function is a function that takes the grid, the number of agents and optional hints and returns a list of starting positions, targets, initial orientations and speed for all agent handles. Implementations can be found in flatland/envs/schedule_generators.py width : int @@ -134,7 +134,7 @@ class RailEnv(Environment): """ self.rail_generator: RailGenerator = rail_generator - self.agent_generator: ScheduleGenerator = agent_generator + self.schedule_generator: ScheduleGenerator = schedule_generator self.rail_generator = rail_generator self.rail: GridTransitionMap = None self.width = width @@ -237,7 +237,7 @@ class RailEnv(Environment): if optionals and 'agents_hints' in optionals: agents_hints = optionals['agents_hints'] self.agents_static = EnvAgentStatic.from_lists( - *self.agent_generator(self.rail, self.get_num_agents(), hints=agents_hints)) + *self.schedule_generator(self.rail, self.get_num_agents(), hints=agents_hints)) self.restart_agents() diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index ef1f9666cc57fb280015f4f6fed7b7dc019ecd89..0ebc6c71c17db308789a4baf0ec99729ec9991e8 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -38,7 +38,7 @@ def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios))) -def complex_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: +def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): start_goal = hints['start_goal'] start_dir = hints['start_dir'] @@ -56,7 +56,7 @@ def complex_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] return generator -def sparse_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: +def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): train_stations = hints['train_stations'] agent_start_targets_nodes = hints['agent_start_targets_nodes'] @@ -111,7 +111,7 @@ def sparse_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] = return generator -def get_rnd_agents_pos_tgt_dir_on_rail(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: +def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: """ Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py index 583830f79146372bfaa8a7d0739688d88781d38b..e5e89f76428bb881d0f72aa60aada97ab02167a5 100644 --- a/tests/test_distance_map.py +++ b/tests/test_distance_map.py @@ -6,7 +6,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail +from flatland.envs.schedule_generators import random_schedule_generator def test_walker(): @@ -28,7 +28,7 @@ def test_walker(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(max_depth=10)), diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index 63f4a9a52c147f0bd8181181d84ff8b8ead529c6..c96e8db00fe721f42667aed4833d034a47f19156 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -9,7 +9,7 @@ from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail +from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail @@ -22,7 +22,7 @@ def test_global_obs(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -92,7 +92,7 @@ def test_reward_function_conflict(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -171,7 +171,7 @@ def test_reward_function_waiting(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 65894b505c000098bd5dc798e9ac0cfd1aed09e1..09f7e5e67a15c55b5070ac8679e43ecc9a14b9da 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -9,7 +9,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail +from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail @@ -22,7 +22,7 @@ def test_dummy_predictor(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), ) @@ -113,7 +113,7 @@ def test_shortest_path_predictor(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -233,7 +233,7 @@ def test_shortest_path_predictor_conflicts(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index 656059ac243dfe0cd5386648603688ec0de4546b..d5dc3ac7af4be6ebd8c5cbeaf705bb710d36d138 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -10,7 +10,7 @@ from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail, complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator """Tests for `flatland` package.""" @@ -27,7 +27,7 @@ def test_load_env(): def test_save_load(): env = RailEnv(width=10, height=10, rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=2) env.reset() agent_1_pos = env.agents_static[0].position @@ -79,7 +79,7 @@ def test_rail_environment_single_agent(): rail_env = RailEnv(width=3, height=3, rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -159,7 +159,7 @@ def test_dead_end(): rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -204,7 +204,7 @@ def test_dead_end(): rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index 5f0cc81e660f885f2d5d6ee2357d85da7d7c903c..db7cac61f4cf3bec4a330694c1864ef7d82bd076 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -1,7 +1,7 @@ from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator -from flatland.envs.schedule_generators import sparse_rail_generator_agents_placer +from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.utils.rendertools import RenderTool @@ -17,7 +17,7 @@ def test_sparse_rail_generator(): seed=5, # Random seed realistic_mode=False # Ordered distribution of nodes ), - agent_generator=sparse_rail_generator_agents_placer(), + schedule_generator=sparse_schedule_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) # reset to initialize agents_static diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 0877e63c1f049a35dbfd5b3686810d0ddb1833bc..eaf782df3255ecfc6ebaa7078935f485497ed359 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -3,7 +3,7 @@ import numpy as np from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator class SingleAgentNavigationObs(TreeObsForRailEnv): @@ -63,7 +63,7 @@ def test_malfunction_process(): height=20, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=2, obs_builder_object=SingleAgentNavigationObs(), stochastic_data=stochastic_data) diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index dff0d2c2cacc8debdf7aa95867b004ba965373c2..8de36c81e4a13c0b7e7e5e556ad79234503ad31a 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -2,7 +2,7 @@ import numpy as np from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import complex_schedule_generator np.random.seed(1) @@ -48,7 +48,7 @@ def test_multi_speed_init(): height=50, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=5) # Initialize the agent with the parameters corresponding to the environment and observation_builder agent = RandomAgent(218, 4) diff --git a/tests/test_speed_classes.py b/tests/test_speed_classes.py index d2754c2a1ff356d91c90fd473184f6d23639fd35..ff5ee56a308ce19559d079b716bde90ad65baf11 100644 --- a/tests/test_speed_classes.py +++ b/tests/test_speed_classes.py @@ -3,7 +3,7 @@ import numpy as np from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import speed_initialization_helper, complex_rail_generator_agents_placer +from flatland.envs.schedule_generators import speed_initialization_helper, complex_schedule_generator def test_speed_initialization_helper(): @@ -22,7 +22,7 @@ def test_rail_env_speed_intializer(): height=50, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), - agent_generator=complex_rail_generator_agents_placer(), + schedule_generator=complex_schedule_generator(), number_of_agents=10) env.reset() actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents)) diff --git a/tests/tests_generators.py b/tests/tests_generators.py index 31cc8d1f29166b3366e606e2fb4c4e2d04567275..610022cafe12fccb2cbbd5da57006e61c89faf28 100644 --- a/tests/tests_generators.py +++ b/tests/tests_generators.py @@ -8,7 +8,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \ random_rail_generator, empty_rail_generator -from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail, complex_rail_generator_agents_placer, \ +from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator, \ agents_from_file from flatland.utils.simple_rail import make_simple_rail @@ -61,7 +61,7 @@ def test_complex_rail_generator(): height=y_dim, number_of_agents=n_agents, rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), - agent_generator=complex_rail_generator_agents_placer() + schedule_generator=complex_schedule_generator() ) assert env.get_num_agents() == 2 assert env.rail.grid.shape == (y_dim, x_dim) @@ -73,7 +73,7 @@ def test_complex_rail_generator(): height=y_dim, number_of_agents=n_agents, rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), - agent_generator=complex_rail_generator_agents_placer() + schedule_generator=complex_schedule_generator() ) assert env.get_num_agents() == 0 assert env.rail.grid.shape == (y_dim, x_dim) @@ -87,7 +87,7 @@ def test_complex_rail_generator(): height=y_dim, number_of_agents=n_agents, rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), - agent_generator=complex_rail_generator_agents_placer() + schedule_generator=complex_schedule_generator() ) assert env.get_num_agents() == n_agents assert env.rail.grid.shape == (y_dim, x_dim) @@ -99,7 +99,7 @@ def test_rail_from_grid_transition_map(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=n_agents ) nr_rail_elements = np.count_nonzero(env.rail.grid) @@ -124,7 +124,7 @@ def tests_rail_from_file(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=3, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -137,7 +137,7 @@ def tests_rail_from_file(): env = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), - agent_generator=agents_from_file(file_name), + schedule_generator=agents_from_file(file_name), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -159,7 +159,7 @@ def tests_rail_from_file(): env2 = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), + schedule_generator=random_schedule_generator(), number_of_agents=3, obs_builder_object=GlobalObsForRailEnv(), ) @@ -173,7 +173,7 @@ def tests_rail_from_file(): env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2), - agent_generator=agents_from_file(file_name_2), + schedule_generator=agents_from_file(file_name_2), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv(), ) @@ -190,7 +190,7 @@ def tests_rail_from_file(): env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), - agent_generator=agents_from_file(file_name), + schedule_generator=agents_from_file(file_name), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv(), ) @@ -208,7 +208,7 @@ def tests_rail_from_file(): env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2), - agent_generator=agents_from_file(file_name_2), + schedule_generator=agents_from_file(file_name_2), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), )