diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py index 9143a86334401889fbd9d2494b2f97a8e6ef5435..94f1c2d4b5defbabb08ffaae8757ae98b2332fd9 100644 --- a/flatland/envs/rail_env_utils.py +++ b/flatland/envs/rail_env_utils.py @@ -3,7 +3,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_file -from flatland.envs.schedule_generators import schedule_from_file +from flatland.envs.line_generators import line_from_file def load_flatland_environment_from_file(file_name: str, @@ -33,7 +33,7 @@ def load_flatland_environment_from_file(file_name: str, max_depth=2, predictor=ShortestPathPredictorForRailEnv(max_depth=10)) environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package), - schedule_generator=schedule_from_file(file_name, load_from_package), + schedule_generator=line_from_file(file_name, load_from_package), number_of_agents=1, obs_builder_object=obs_builder_object, record_steps=record_steps, diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py index b31ec52524599cc026bd086acdacf4e69c8c2774..5471cd47e99b2b5f49031c2bebe3ab470331e32b 100644 --- a/flatland/evaluators/client.py +++ b/flatland/evaluators/client.py @@ -15,7 +15,7 @@ import flatland from flatland.envs.malfunction_generators import malfunction_from_file from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_file -from flatland.envs.schedule_generators import schedule_from_file +from flatland.envs.line_generators import line_from_file from flatland.evaluators import messages from flatland.core.env_observation_builder import DummyObservationBuilder diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py index 03b94380f232773044a2733802e6df4ef9d1918f..8e3b4155158d4fbce32a6e6868f0a5b72c85127b 100644 --- a/flatland/evaluators/service.py +++ b/flatland/evaluators/service.py @@ -26,7 +26,7 @@ from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.malfunction_generators import malfunction_from_file from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_file -from flatland.envs.schedule_generators import schedule_from_file +from flatland.envs.line_generators import line_from_file from flatland.evaluators import aicrowd_helpers from flatland.evaluators import messages from flatland.utils.rendertools import RenderTool diff --git a/tests/test_action_plan.py b/tests/test_action_plan.py index 9a03eb8a9f89e1edaac558e563a3c0544b4d6b5c..815ecbcd7c1e091681b67d5c095b890ae6ab798a 100644 --- a/tests/test_action_plan.py +++ b/tests/test_action_plan.py @@ -6,7 +6,7 @@ from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.rail_trainrun_data_structures import Waypoint -from flatland.envs.schedule_generators import random_schedule_generator +from flatland.envs.line_generators import random_line_generator from flatland.utils.rendertools import RenderTool, AgentRenderVariant from flatland.utils.simple_rail import make_simple_rail @@ -17,7 +17,7 @@ def test_action_plan(rendering: bool = False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=77), + line_generator=random_line_generator(seed=77), number_of_agents=2, obs_builder_object=GlobalObsForRailEnv(), remove_agents_at_target=True @@ -34,25 +34,25 @@ def test_action_plan(rendering: bool = False): for handle, agent in enumerate(env.agents): print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target)) - chosen_path_dict = {0: [TrainrunWaypoint(scheduled_at=0, waypoint=Waypoint(position=(3, 0), direction=3)), - TrainrunWaypoint(scheduled_at=2, waypoint=Waypoint(position=(3, 1), direction=1)), - TrainrunWaypoint(scheduled_at=3, waypoint=Waypoint(position=(3, 2), direction=1)), - TrainrunWaypoint(scheduled_at=14, waypoint=Waypoint(position=(3, 3), direction=1)), - TrainrunWaypoint(scheduled_at=15, waypoint=Waypoint(position=(3, 4), direction=1)), - TrainrunWaypoint(scheduled_at=16, waypoint=Waypoint(position=(3, 5), direction=1)), - TrainrunWaypoint(scheduled_at=17, waypoint=Waypoint(position=(3, 6), direction=1)), - TrainrunWaypoint(scheduled_at=18, waypoint=Waypoint(position=(3, 7), direction=1)), - TrainrunWaypoint(scheduled_at=19, waypoint=Waypoint(position=(3, 8), direction=1)), - TrainrunWaypoint(scheduled_at=20, waypoint=Waypoint(position=(3, 8), direction=5))], - 1: [TrainrunWaypoint(scheduled_at=0, waypoint=Waypoint(position=(3, 8), direction=3)), - TrainrunWaypoint(scheduled_at=3, waypoint=Waypoint(position=(3, 7), direction=3)), - TrainrunWaypoint(scheduled_at=5, waypoint=Waypoint(position=(3, 6), direction=3)), - TrainrunWaypoint(scheduled_at=7, waypoint=Waypoint(position=(3, 5), direction=3)), - TrainrunWaypoint(scheduled_at=9, waypoint=Waypoint(position=(3, 4), direction=3)), - TrainrunWaypoint(scheduled_at=11, waypoint=Waypoint(position=(3, 3), direction=3)), - TrainrunWaypoint(scheduled_at=13, waypoint=Waypoint(position=(2, 3), direction=0)), - TrainrunWaypoint(scheduled_at=15, waypoint=Waypoint(position=(1, 3), direction=0)), - TrainrunWaypoint(scheduled_at=17, waypoint=Waypoint(position=(0, 3), direction=0))]} + chosen_path_dict = {0: [TrainrunWaypoint(lined_at=0, waypoint=Waypoint(position=(3, 0), direction=3)), + TrainrunWaypoint(lined_at=2, waypoint=Waypoint(position=(3, 1), direction=1)), + TrainrunWaypoint(lined_at=3, waypoint=Waypoint(position=(3, 2), direction=1)), + TrainrunWaypoint(lined_at=14, waypoint=Waypoint(position=(3, 3), direction=1)), + TrainrunWaypoint(lined_at=15, waypoint=Waypoint(position=(3, 4), direction=1)), + TrainrunWaypoint(lined_at=16, waypoint=Waypoint(position=(3, 5), direction=1)), + TrainrunWaypoint(lined_at=17, waypoint=Waypoint(position=(3, 6), direction=1)), + TrainrunWaypoint(lined_at=18, waypoint=Waypoint(position=(3, 7), direction=1)), + TrainrunWaypoint(lined_at=19, waypoint=Waypoint(position=(3, 8), direction=1)), + TrainrunWaypoint(lined_at=20, waypoint=Waypoint(position=(3, 8), direction=5))], + 1: [TrainrunWaypoint(lined_at=0, waypoint=Waypoint(position=(3, 8), direction=3)), + TrainrunWaypoint(lined_at=3, waypoint=Waypoint(position=(3, 7), direction=3)), + TrainrunWaypoint(lined_at=5, waypoint=Waypoint(position=(3, 6), direction=3)), + TrainrunWaypoint(lined_at=7, waypoint=Waypoint(position=(3, 5), direction=3)), + TrainrunWaypoint(lined_at=9, waypoint=Waypoint(position=(3, 4), direction=3)), + TrainrunWaypoint(lined_at=11, waypoint=Waypoint(position=(3, 3), direction=3)), + TrainrunWaypoint(lined_at=13, waypoint=Waypoint(position=(2, 3), direction=0)), + TrainrunWaypoint(lined_at=15, waypoint=Waypoint(position=(1, 3), direction=0)), + TrainrunWaypoint(lined_at=17, waypoint=Waypoint(position=(0, 3), direction=0))]} expected_action_plan = [[ # take action to enter the grid ActionPlanElement(0, RailEnvActions.MOVE_FORWARD), diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py index c6a96fbefff68c4dbe448fc666e94317729aae6b..850af32dc6dd83580219dc5a608bb954343c27f8 100644 --- a/tests/test_distance_map.py +++ b/tests/test_distance_map.py @@ -6,7 +6,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.schedule_generators import random_schedule_generator +from flatland.envs.line_generators import random_line_generator def test_walker(): @@ -28,7 +28,7 @@ def test_walker(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(max_depth=10)), diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py index 05127379b78d96e8b68be2b82b8210a6ba86546b..4f507415e8acc25a1d931b68f249d568839fd609 100644 --- a/tests/test_flaltland_rail_agent_status.py +++ b/tests/test_flaltland_rail_agent_status.py @@ -4,7 +4,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.schedule_generators import random_schedule_generator +from flatland.envs.line_generators import random_line_generator from flatland.utils.simple_rail import make_simple_rail from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay @@ -13,7 +13,7 @@ def test_initial_status(): """Test that agent lifecycle works correctly ready-to-depart -> active -> done.""" rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1, + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), remove_agents_at_target=False) env.reset() @@ -121,7 +121,7 @@ def test_status_done_remove(): """Test that agent lifecycle works correctly ready-to-depart -> active -> done.""" rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1, + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), remove_agents_at_target=True) env.reset() diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py index a569aa35534385698369980566c426cf72b7bb4b..fc9a49106ca6c1e0eb6c7754b942cb19b24d66a5 100644 --- a/tests/test_flatland_core_transition_map.py +++ b/tests/test_flatland_core_transition_map.py @@ -6,7 +6,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.schedule_generators import random_schedule_generator +from flatland.envs.line_generators import random_line_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail, make_simple_rail_unconnected @@ -70,7 +70,7 @@ def test_path_exists(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -134,7 +134,7 @@ def test_path_not_exists(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index 6e5a374d6606800b311205bd86d99600db447b91..19ae530d681c864b87dda6aed3423013e36f3c14 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -10,7 +10,7 @@ from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.schedule_generators import random_schedule_generator +from flatland.envs.line_generators import random_line_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail @@ -21,7 +21,7 @@ def test_global_obs(): rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1, + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) global_obs, info = env.reset() @@ -93,7 +93,7 @@ def _step_along_shortest_path(env, obs_builder, rail): def test_reward_function_conflict(rendering=False): rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=2, + line_generator=random_line_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) obs_builder: TreeObsForRailEnv = env.obs_builder env.reset() @@ -181,7 +181,7 @@ def test_reward_function_conflict(rendering=False): def test_reward_function_waiting(rendering=False): rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=2, + line_generator=random_line_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), remove_agents_at_target=False) obs_builder: TreeObsForRailEnv = env.obs_builder diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index c649517108597de87dd8195169b4672434590c0f..e573ef7eb05057670a1fdfeaac2cccb7852651ea 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -12,7 +12,7 @@ from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env_shortest_paths import get_shortest_paths from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.rail_trainrun_data_structures import Waypoint -from flatland.envs.schedule_generators import random_schedule_generator +from flatland.envs.line_generators import random_line_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail @@ -25,7 +25,7 @@ def test_dummy_predictor(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), ) @@ -116,7 +116,7 @@ def test_shortest_path_predictor(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -247,7 +247,7 @@ def test_shortest_path_predictor_conflicts(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + line_generator=random_line_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index 0c865502f94b624ca9712a20cb83af72b0310357..c531359a2d3d573db32cdec7515155511e5f6d57 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -11,7 +11,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import complex_rail_generator, rail_from_file from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator, schedule_from_file +from flatland.envs.line_generators import random_line_generator, complex_line_generator, line_from_file from flatland.utils.simple_rail import make_simple_rail from flatland.envs.persistence import RailEnvPersister from flatland.utils.rendertools import RenderTool @@ -38,7 +38,7 @@ def test_load_env(): def test_save_load(): env = RailEnv(width=10, height=10, rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1), - schedule_generator=complex_schedule_generator(), number_of_agents=2) + line_generator=complex_line_generator(), number_of_agents=2) env.reset() agent_1_pos = env.agents[0].position agent_1_dir = env.agents[0].direction @@ -68,7 +68,7 @@ def test_save_load(): def test_save_load_mpk(): env = RailEnv(width=10, height=10, rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1), - schedule_generator=complex_schedule_generator(), number_of_agents=2) + line_generator=complex_line_generator(), number_of_agents=2) env.reset() os.makedirs("tmp", exist_ok=True) @@ -120,7 +120,7 @@ def test_rail_environment_single_agent(show=False): rail = GridTransitionMap(width=3, height=3, transitions=transitions) rail.grid = rail_map rail_env = RailEnv(width=3, height=3, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1, + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) else: rail_env, env_dict = RailEnvPersister.load_new("test_env_loop.pkl", "env_data.tests") @@ -203,7 +203,7 @@ def test_rail_environment_single_agent(show=False): rail_env.agents[0].direction = 0 - # JW - to avoid problem with random_schedule_generator. + # JW - to avoid problem with random_line_generator. #rail_env.agents[0].position = (1,2) iStep = 0 @@ -246,7 +246,7 @@ def test_dead_end(): rail.grid = rail_map rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1, + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) # We try the configuration in the 4 directions: @@ -269,7 +269,7 @@ def test_dead_end(): rail.grid = rail_map rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1, + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) rail_env.reset() @@ -284,7 +284,7 @@ def test_dead_end(): def test_get_entry_directions(): rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1, + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() @@ -319,7 +319,7 @@ def test_rail_env_reset(): rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=3, + line_generator=random_line_generator(), number_of_agents=3, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() @@ -331,7 +331,7 @@ def test_rail_env_reset(): agents_initial = env.agents #env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), - # schedule_generator=schedule_from_file(file_name), number_of_agents=1, + # line_generator=line_from_file(file_name), number_of_agents=1, # obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) #env2.reset(False, False, False) env2, env2_dict = RailEnvPersister.load_new(file_name) @@ -343,7 +343,7 @@ def test_rail_env_reset(): assert agents_initial == agents_loaded env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), - schedule_generator=schedule_from_file(file_name), number_of_agents=1, + line_generator=line_from_file(file_name), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env3.reset(False, True, False) rails_loaded = env3.rail.grid @@ -353,7 +353,7 @@ def test_rail_env_reset(): assert agents_initial == agents_loaded env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), - schedule_generator=schedule_from_file(file_name), number_of_agents=1, + line_generator=line_from_file(file_name), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env4.reset(True, False, False) rails_loaded = env4.rail.grid diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py index 84504cf2f2be1867e450e3520d566fb1eee55cf9..b2ee9b015c74f74c651aa6396dee3c6ea5a49ce4 100644 --- a/tests/test_flatland_envs_rail_env_shortest_paths.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -9,7 +9,7 @@ from flatland.envs.rail_env_shortest_paths import get_shortest_paths, get_k_shor from flatland.envs.rail_env_utils import load_flatland_environment_from_file from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.rail_trainrun_data_structures import Waypoint -from flatland.envs.schedule_generators import random_schedule_generator +from flatland.envs.line_generators import random_line_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_disconnected_simple_rail, make_simple_rail_with_alternatives from flatland.envs.persistence import RailEnvPersister @@ -19,7 +19,7 @@ def test_get_shortest_paths_unreachable(): rail, rail_map = make_disconnected_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1, + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) env.reset() @@ -238,7 +238,7 @@ def test_get_k_shortest_paths(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv(), ) diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index c6e151b36830af23d566837d8fe1f33877bf69c6..3e74d720d4276dae588a3989843ada9d1761671d 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -7,7 +7,7 @@ from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator -from flatland.envs.schedule_generators import sparse_schedule_generator +from flatland.envs.line_generators import sparse_line_generator from flatland.utils.rendertools import RenderTool @@ -17,7 +17,7 @@ def test_sparse_rail_generator(): seed=5, grid_mode=False ), - schedule_generator=sparse_schedule_generator(), number_of_agents=10, + line_generator=sparse_line_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) env.reset(False, False, True) for r in range(env.height): @@ -602,7 +602,7 @@ def test_sparse_rail_generator_deterministic(): seed=215545, # Random seed grid_mode=True ), - schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=1) + line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1) env.reset() # for r in range(env.height): # for c in range(env.width): @@ -1371,7 +1371,7 @@ def test_rail_env_action_required_info(): max_rails_between_cities=3, seed=5, # Random seed grid_mode=False # Ordered distribution of nodes - ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=10, + ), line_generator=sparse_line_generator(speed_ration_map), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv(), remove_agents_at_target=False) env_only_if_action_required = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator( @@ -1380,7 +1380,7 @@ def test_rail_env_action_required_info(): seed=5, # Random seed grid_mode=False # Ordered distribution of nodes - ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=10, + ), line_generator=sparse_line_generator(speed_ration_map), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv(), remove_agents_at_target=False) env_renderer = RenderTool(env_always_action, gl="PILSVG", ) @@ -1442,7 +1442,7 @@ def test_rail_env_malfunction_speed_info(): seed=5, grid_mode=False ), - schedule_generator=sparse_schedule_generator(), number_of_agents=10, + line_generator=sparse_line_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) env.reset(False, False, True) @@ -1476,7 +1476,7 @@ def test_sparse_generator_with_too_man_cities_does_not_break_down(): max_rails_between_cities=3, seed=5, grid_mode=False - ), schedule_generator=sparse_schedule_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) + ), line_generator=sparse_line_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) def test_sparse_generator_with_illegal_params_aborts(): @@ -1489,7 +1489,7 @@ def test_sparse_generator_with_illegal_params_aborts(): max_rails_between_cities=3, seed=5, grid_mode=False - ), schedule_generator=sparse_schedule_generator(), number_of_agents=10, + ), line_generator=sparse_line_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()).reset() with unittest.TestCase.assertRaises(test_sparse_generator_with_illegal_params_aborts, ValueError): @@ -1498,7 +1498,7 @@ def test_sparse_generator_with_illegal_params_aborts(): max_rails_between_cities=3, seed=5, grid_mode=False - ), schedule_generator=sparse_schedule_generator(), number_of_agents=10, + ), line_generator=sparse_line_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()).reset() @@ -1515,7 +1515,7 @@ def test_sparse_generator_changes_to_grid_mode(): max_rails_in_city=2, seed=15, grid_mode=False - ), schedule_generator=sparse_schedule_generator(), number_of_agents=10, + ), line_generator=sparse_line_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) for test_run in range(10): with warnings.catch_warnings(record=True) as w: diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index eaa3112708f3f0e5d255b7e454078d9a59e7ca22..53915102c539c8e7703434100634c36fda001049 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -10,7 +10,7 @@ from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.schedule_generators import random_schedule_generator +from flatland.envs.line_generators import random_line_generator from flatland.utils.simple_rail import make_simple_rail2 from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay @@ -77,7 +77,7 @@ def test_malfunction_process(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + line_generator=random_line_generator(), number_of_agents=1, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), obs_builder_object=SingleAgentNavigationObs() @@ -131,7 +131,7 @@ def test_malfunction_process_statistically(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + line_generator=random_line_generator(), number_of_agents=10, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), obs_builder_object=SingleAgentNavigationObs() @@ -178,7 +178,7 @@ def test_malfunction_before_entry(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + line_generator=random_line_generator(), number_of_agents=10, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), obs_builder_object=SingleAgentNavigationObs() @@ -222,7 +222,7 @@ def test_malfunction_values_and_behavior(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + line_generator=random_line_generator(), number_of_agents=1, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), obs_builder_object=SingleAgentNavigationObs() @@ -251,7 +251,7 @@ def test_initial_malfunction(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=10), + line_generator=random_line_generator(seed=10), number_of_agents=1, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), # Malfunction data generator @@ -316,7 +316,7 @@ def test_initial_malfunction_stop_moving(): rail, rail_map = make_simple_rail2() env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1, + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=SingleAgentNavigationObs()) env.reset() @@ -400,7 +400,7 @@ def test_initial_malfunction_do_nothing(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + line_generator=random_line_generator(), number_of_agents=1, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), # Malfunction data generator @@ -477,7 +477,7 @@ def tests_random_interference_from_outside(): # Set fixed malfunction duration for this test rail, rail_map = make_simple_rail2() env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1) + line_generator=random_line_generator(seed=2), number_of_agents=1, random_seed=1) env.reset() env.agents[0].speed_data['speed'] = 0.33 env.reset(False, False, False, random_seed=10) @@ -501,7 +501,7 @@ def tests_random_interference_from_outside(): random.seed(47) np.random.seed(1234) env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1) + line_generator=random_line_generator(seed=2), number_of_agents=1, random_seed=1) env.reset() env.agents[0].speed_data['speed'] = 0.33 env.reset(False, False, False, random_seed=10) @@ -533,7 +533,7 @@ def test_last_malfunction_step(): rail, rail_map = make_simple_rail2() env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1) + line_generator=random_line_generator(seed=2), number_of_agents=1, random_seed=1) env.reset() env.agents[0].speed_data['speed'] = 1. / 3. env.agents[0].target = (0, 0) diff --git a/tests/test_flatland_multiprocessing.py b/tests/test_flatland_multiprocessing.py index 23cfeeacdf160ce7a8389e4a1b6f42078650ade2..087c238e2ce5b4f764e262869f08f72798d7d824 100644 --- a/tests/test_flatland_multiprocessing.py +++ b/tests/test_flatland_multiprocessing.py @@ -6,7 +6,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.schedule_generators import random_schedule_generator +from flatland.envs.line_generators import random_line_generator from flatland.utils.simple_rail import make_simple_rail """Tests for `flatland` package.""" @@ -19,7 +19,7 @@ def test_multiprocessing_tree_obs(): obs_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=number_of_agents, + line_generator=random_line_generator(), number_of_agents=number_of_agents, obs_builder_object=obs_builder) env.reset(True, True) diff --git a/tests/test_flatland_schedule_from_file.py b/tests/test_flatland_schedule_from_file.py index 52a64a19343ad320d8c26922ed5e42e400821c73..0b903eae8b56864afc60d0d0d01923e337796479 100644 --- a/tests/test_flatland_schedule_from_file.py +++ b/tests/test_flatland_schedule_from_file.py @@ -3,11 +3,11 @@ from test_utils import create_and_save_env from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator, random_rail_generator, complex_rail_generator, \ rail_from_file -from flatland.envs.schedule_generators import sparse_schedule_generator, random_schedule_generator, \ - complex_schedule_generator, schedule_from_file +from flatland.envs.line_generators import sparse_line_generator, random_line_generator, \ + complex_line_generator, line_from_file -def test_schedule_from_file_sparse(): +def test_line_from_file_sparse(): """ Test to see that all parameters are loaded as expected Returns @@ -27,17 +27,17 @@ def test_schedule_from_file_sparse(): max_rails_between_cities=3, max_rails_in_city=6, ) - schedule_generator = sparse_schedule_generator(speed_ration_map) + line_generator = sparse_line_generator(speed_ration_map) create_and_save_env(file_name="./sparse_env_test.pkl", rail_generator=rail_generator, - schedule_generator=schedule_generator) + line_generator=line_generator) # Sparse generator rail_generator = rail_from_file("./sparse_env_test.pkl") - schedule_generator = schedule_from_file("./sparse_env_test.pkl") + line_generator = line_from_file("./sparse_env_test.pkl") sparse_env_from_file = RailEnv(width=1, height=1, rail_generator=rail_generator, - schedule_generator=schedule_generator) + line_generator=line_generator) sparse_env_from_file.reset(True, True) # Assert loaded agent number is correct @@ -48,7 +48,7 @@ def test_schedule_from_file_sparse(): -def test_schedule_from_file_random(): +def test_line_from_file_random(): """ Test to see that all parameters are loaded as expected Returns @@ -63,17 +63,17 @@ def test_schedule_from_file_random(): # Generate random test env rail_generator = random_rail_generator() - schedule_generator = random_schedule_generator(speed_ration_map) + line_generator = random_line_generator(speed_ration_map) create_and_save_env(file_name="./random_env_test.pkl", rail_generator=rail_generator, - schedule_generator=schedule_generator) + line_generator=line_generator) # Random generator rail_generator = rail_from_file("./random_env_test.pkl") - schedule_generator = schedule_from_file("./random_env_test.pkl") + line_generator = line_from_file("./random_env_test.pkl") random_env_from_file = RailEnv(width=1, height=1, rail_generator=rail_generator, - schedule_generator=schedule_generator) + line_generator=line_generator) random_env_from_file.reset(True, True) # Assert loaded agent number is correct @@ -85,7 +85,7 @@ def test_schedule_from_file_random(): -def test_schedule_from_file_complex(): +def test_line_from_file_complex(): """ Test to see that all parameters are loaded as expected Returns @@ -103,19 +103,19 @@ def test_schedule_from_file_complex(): nr_extra=1, min_dist=8, max_dist=99999) - schedule_generator = complex_schedule_generator(speed_ration_map) + line_generator = complex_line_generator(speed_ration_map) create_and_save_env(file_name="./complex_env_test.pkl", rail_generator=rail_generator, - schedule_generator=schedule_generator) + line_generator=line_generator) # Load the different envs and check the parameters # Complex generator rail_generator = rail_from_file("./complex_env_test.pkl") - schedule_generator = schedule_from_file("./complex_env_test.pkl") + line_generator = line_from_file("./complex_env_test.pkl") complex_env_from_file = RailEnv(width=1, height=1, rail_generator=rail_generator, - schedule_generator=schedule_generator) + line_generator=line_generator) complex_env_from_file.reset(True, True) # Assert loaded agent number is correct diff --git a/tests/test_generators.py b/tests/test_generators.py index c723c194f179efcc191f80fb93a3e5370e5469c9..61cf9523fb3f4e53dd29ac440960da8f866ad230 100644 --- a/tests/test_generators.py +++ b/tests/test_generators.py @@ -8,8 +8,8 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \ random_rail_generator, empty_rail_generator -from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator, \ - schedule_from_file +from flatland.envs.line_generators import random_line_generator, complex_line_generator, \ + line_from_file from flatland.utils.simple_rail import make_simple_rail from flatland.envs.persistence import RailEnvPersister @@ -52,7 +52,7 @@ def test_complex_rail_generator(): # Check that agent number is changed to fit generated level env = RailEnv(width=x_dim, height=y_dim, rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), - schedule_generator=complex_schedule_generator(), number_of_agents=n_agents) + line_generator=complex_line_generator(), number_of_agents=n_agents) env.reset() assert env.get_num_agents() == 2 assert env.rail.grid.shape == (y_dim, x_dim) @@ -62,7 +62,7 @@ def test_complex_rail_generator(): # Check that no agents are generated when level cannot be generated env = RailEnv(width=x_dim, height=y_dim, rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), - schedule_generator=complex_schedule_generator(), number_of_agents=n_agents) + line_generator=complex_line_generator(), number_of_agents=n_agents) env.reset() assert env.get_num_agents() == 0 assert env.rail.grid.shape == (y_dim, x_dim) @@ -74,7 +74,7 @@ def test_complex_rail_generator(): env = RailEnv(width=x_dim, height=y_dim, rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), - schedule_generator=complex_schedule_generator(), number_of_agents=n_agents) + line_generator=complex_line_generator(), number_of_agents=n_agents) env.reset() assert env.get_num_agents() == n_agents assert env.rail.grid.shape == (y_dim, x_dim) @@ -84,7 +84,7 @@ def test_rail_from_grid_transition_map(): rail, rail_map = make_simple_rail() n_agents = 3 env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=n_agents) + line_generator=random_line_generator(), number_of_agents=n_agents) env.reset(False, False, True) nr_rail_elements = np.count_nonzero(env.rail.grid) @@ -106,7 +106,7 @@ def tests_rail_from_file(): rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=3, + line_generator=random_line_generator(), number_of_agents=3, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() #env.save(file_name) @@ -116,7 +116,7 @@ def tests_rail_from_file(): agents_initial = env.agents env = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), - schedule_generator=schedule_from_file(file_name), number_of_agents=1, + line_generator=line_from_file(file_name), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() rails_loaded = env.rail.grid @@ -134,7 +134,7 @@ def tests_rail_from_file(): file_name_2 = "test_without_distance_map.pkl" env2 = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), + rail_generator=rail_from_grid_transition_map(rail), line_generator=random_line_generator(), number_of_agents=3, obs_builder_object=GlobalObsForRailEnv()) env2.reset() #env2.save(file_name_2) @@ -144,7 +144,7 @@ def tests_rail_from_file(): agents_initial_2 = env2.agents env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2), - schedule_generator=schedule_from_file(file_name_2), number_of_agents=1, + line_generator=line_from_file(file_name_2), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) env2.reset() rails_loaded_2 = env2.rail.grid @@ -157,7 +157,7 @@ def tests_rail_from_file(): # Test to save with distance map and load without env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), - schedule_generator=schedule_from_file(file_name), number_of_agents=1, + line_generator=line_from_file(file_name), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) env3.reset() rails_loaded_3 = env3.rail.grid @@ -172,7 +172,7 @@ def tests_rail_from_file(): env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2), - schedule_generator=schedule_from_file(file_name_2), + line_generator=line_from_file(file_name_2), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), ) diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py index 0cd2ac1b5fb0af5583e66c931817ccd7ce8b7f71..5b090681813b9976285fe5a726ff9e7dc0226b08 100644 --- a/tests/test_global_observation.py +++ b/tests/test_global_observation.py @@ -4,7 +4,7 @@ from flatland.envs.agent_utils import EnvAgent, RailAgentStatus from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import sparse_rail_generator -from flatland.envs.schedule_generators import sparse_schedule_generator +from flatland.envs.line_generators import sparse_line_generator def test_get_global_observation(): @@ -26,7 +26,7 @@ def test_get_global_observation(): seed=15, grid_mode=False ), - schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=number_of_agents, + line_generator=sparse_line_generator(speed_ration_map), number_of_agents=number_of_agents, obs_builder_object=GlobalObsForRailEnv()) env.reset() diff --git a/tests/test_malfunction_generators.py b/tests/test_malfunction_generators.py index 2593361e5922dd3078b614997e6306c1ab5549d5..917051bef982e1595b330a61ca1bd7b1b8f1935b 100644 --- a/tests/test_malfunction_generators.py +++ b/tests/test_malfunction_generators.py @@ -2,7 +2,7 @@ from flatland.envs.malfunction_generators import malfunction_from_params, malfun single_malfunction_generator, MalfunctionParameters from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.schedule_generators import random_schedule_generator +from flatland.envs.line_generators import random_line_generator from flatland.utils.simple_rail import make_simple_rail2 from flatland.envs.persistence import RailEnvPersister @@ -22,7 +22,7 @@ def test_malfanction_from_params(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + line_generator=random_line_generator(), number_of_agents=10, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data) ) @@ -49,7 +49,7 @@ def test_malfanction_to_and_from_file(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + line_generator=random_line_generator(), number_of_agents=10, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data) ) @@ -62,7 +62,7 @@ def test_malfanction_to_and_from_file(): env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + line_generator=random_line_generator(), number_of_agents=10, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data) ) @@ -87,7 +87,7 @@ def test_single_malfunction_generator(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + line_generator=random_line_generator(), number_of_agents=10, malfunction_generator_and_process_data=single_malfunction_generator(earlierst_malfunction=10, malfunction_duration=5) diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index 0467ce5ff34c12e98f854bc44a0f42caa4ee3649..08b46d00f4a458cfd105d568ca2a045f122cc8bc 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -5,7 +5,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid_transition_map -from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator +from flatland.envs.line_generators import complex_line_generator, random_line_generator from flatland.utils.simple_rail import make_simple_rail from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay @@ -49,7 +49,7 @@ class RandomAgent: def test_multi_speed_init(): env = RailEnv(width=50, height=50, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, - seed=1), schedule_generator=complex_schedule_generator(), + seed=1), line_generator=complex_line_generator(), number_of_agents=5) # Initialize the agent with the parameters corresponding to the environment and observation_builder agent = RandomAgent(218, 4) @@ -94,7 +94,7 @@ def test_multispeed_actions_no_malfunction_no_blocking(): """Test that actions are correctly performed on cell exit for a single agent.""" rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1, + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() @@ -194,7 +194,7 @@ def test_multispeed_actions_no_malfunction_blocking(): """The second agent blocks the first because it is slower.""" rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=2, + line_generator=random_line_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() set_penalties_for_replay(env) @@ -378,7 +378,7 @@ def test_multispeed_actions_malfunction_no_blocking(): """Test on a single agent whether action on cell exit work correctly despite malfunction.""" rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1, + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() @@ -512,7 +512,7 @@ def test_multispeed_actions_no_malfunction_invalid_actions(): """Test that actions are correctly performed on cell exit for a single agent.""" rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1, + line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py index 17a658f0d939f07dd093f518939c8a6cde54a526..cda7c8c83a37b1ef4ad5750ad89427ace2baf7a0 100644 --- a/tests/test_random_seeding.py +++ b/tests/test_random_seeding.py @@ -4,7 +4,7 @@ from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_grid_transition_map, sparse_rail_generator -from flatland.envs.schedule_generators import random_schedule_generator, sparse_schedule_generator +from flatland.envs.line_generators import random_line_generator, sparse_line_generator from flatland.utils.simple_rail import make_simple_rail2 @@ -15,7 +15,7 @@ def test_random_seeding(): # Move target to unreachable position in order to not interfere with test for idx in range(100): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=12), number_of_agents=10) + line_generator=random_line_generator(seed=12), number_of_agents=10) env.reset(True, True, False, random_seed=1) env.agents[0].target = (0, 0) @@ -49,11 +49,11 @@ def test_seeding_and_observations(): # Make two seperate envs with different observation builders # Global Observation env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=12), number_of_agents=10, + line_generator=random_line_generator(seed=12), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) # Tree Observation env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=12), number_of_agents=10, + line_generator=random_line_generator(seed=12), number_of_agents=10, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset(False, False, False, random_seed=12) @@ -107,12 +107,12 @@ def test_seeding_and_malfunction(): # Global Observation for tests in range(1, 100): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=10, + line_generator=random_line_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) # Tree Observation env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=10, + line_generator=random_line_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) env.reset(True, False, True, random_seed=tests) @@ -172,7 +172,7 @@ def test_reproducability_env(): seed=215545, # Random seed grid_mode=True ), - schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=1) + line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1) env.reset(True, True, True, random_seed=10) excpeted_grid = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], @@ -233,7 +233,7 @@ def test_reproducability_env(): seed=215545, # Random seed grid_mode=True ), - schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=1) + line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1) np.random.seed(10) for i in range(10): np.random.randn() diff --git a/tests/test_speed_classes.py b/tests/test_speed_classes.py index de8b13233a7cadd9c19331c959fc995b325de101..008992ee20fb626bea272d1cb70d73c456ff6cff 100644 --- a/tests/test_speed_classes.py +++ b/tests/test_speed_classes.py @@ -3,7 +3,7 @@ import numpy as np from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import speed_initialization_helper, complex_schedule_generator +from flatland.envs.line_generators import speed_initialization_helper, complex_line_generator def test_speed_initialization_helper(): @@ -21,7 +21,7 @@ def test_rail_env_speed_intializer(): env = RailEnv(width=50, height=50, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, - seed=1), schedule_generator=complex_schedule_generator(), + seed=1), line_generator=complex_line_generator(), number_of_agents=10) env.reset() actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents)) diff --git a/tests/test_utils.py b/tests/test_utils.py index 99f731e47d488d01f281acbdc2f556b92dbf0b6d..4e7c30ca72811b78c564d0c11a011dd4b9e04998 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,7 +9,7 @@ from flatland.envs.agent_utils import EnvAgent, RailAgentStatus from flatland.envs.malfunction_generators import MalfunctionParameters, malfunction_from_params from flatland.envs.rail_env import RailEnvActions, RailEnv from flatland.envs.rail_generators import RailGenerator -from flatland.envs.schedule_generators import ScheduleGenerator +from flatland.envs.line_generators import LineGenerator from flatland.utils.rendertools import RenderTool from flatland.envs.persistence import RailEnvPersister @@ -136,7 +136,7 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: _assert(a, rewards_dict[a], replay.reward, 'reward') -def create_and_save_env(file_name: str, schedule_generator: ScheduleGenerator, rail_generator: RailGenerator): +def create_and_save_env(file_name: str, line_generator: LineGenerator, rail_generator: RailGenerator): stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurence min_duration=15, # Minimal duration of malfunction max_duration=50 # Max duration of malfunction @@ -145,7 +145,7 @@ def create_and_save_env(file_name: str, schedule_generator: ScheduleGenerator, r env = RailEnv(width=30, height=30, rail_generator=rail_generator, - schedule_generator=schedule_generator, + line_generator=line_generator, number_of_agents=10, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), remove_agents_at_target=True)