diff --git a/tests/test_flatland_schedule_generators.py b/tests/test_flatland_schedule_generators.py index 487d619a347f7d4276f3e7e88a2986ac804669b0..915bf06d75f69105472d82acc91fe6bf156fc9c8 100644 --- a/tests/test_flatland_schedule_generators.py +++ b/tests/test_flatland_schedule_generators.py @@ -1,6 +1,6 @@ -from flatland.envs.rail_env import RailEnv from test_utils import create_and_save_env +from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator, random_rail_generator, complex_rail_generator, \ rail_from_file from flatland.envs.schedule_generators import sparse_schedule_generator, random_schedule_generator, \ @@ -14,6 +14,12 @@ def test_schedule_from_file(): ------- """ + # Different agent types (trains) with different speeds. + speed_ration_map = {1.: 0.25, # Fast passenger train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train + 1. / 4.: 0.25} # Slow freight train + # Generate Sparse test env rail_generator = sparse_rail_generator(max_num_cities=5, seed=1, @@ -21,30 +27,14 @@ def test_schedule_from_file(): max_rails_between_cities=3, max_rails_in_city=6, ) - - # Different agent types (trains) with different speeds. - speed_ration_map = {1.: 0.25, # Fast passenger train - 1. / 2.: 0.25, # Fast freight train - 1. / 3.: 0.25, # Slow commuter train - 1. / 4.: 0.25} # Slow freight train - - schedule_generator = sparse_schedule_generator( - speed_ration_map) + schedule_generator = sparse_schedule_generator(speed_ration_map) create_and_save_env(file_name="./sparse_env_test.pkl", rail_generator=rail_generator, schedule_generator=schedule_generator) # Generate random test env rail_generator = random_rail_generator() - - # Different agent types (trains) with different speeds. - speed_ration_map = {1.: 0.25, # Fast passenger train - 1. / 2.: 0.25, # Fast freight train - 1. / 3.: 0.25, # Slow commuter train - 1. / 4.: 0.25} # Slow freight train - - schedule_generator = random_schedule_generator( - speed_ration_map) + schedule_generator = random_schedule_generator(speed_ration_map) create_and_save_env(file_name="./random_env_test.pkl", rail_generator=rail_generator, schedule_generator=schedule_generator) @@ -54,15 +44,7 @@ def test_schedule_from_file(): nr_extra=1, min_dist=8, max_dist=99999) - - # Different agent types (trains) with different speeds. - speed_ration_map = {1.: 0.25, # Fast passenger train - 1. / 2.: 0.25, # Fast freight train - 1. / 3.: 0.25, # Slow commuter train - 1. / 4.: 0.25} # Slow freight train - - schedule_generator = complex_schedule_generator( - speed_ration_map) + schedule_generator = complex_schedule_generator(speed_ration_map) create_and_save_env(file_name="./complex_env_test.pkl", rail_generator=rail_generator, schedule_generator=schedule_generator) @@ -72,7 +54,8 @@ def test_schedule_from_file(): # Sparse generator rail_generator = rail_from_file("./sparse_env_test.pkl") schedule_generator = schedule_from_file("./sparse_env_test.pkl") - sparse_env_from_file = RailEnv(width=1, height=1, rail_generator=rail_generator, schedule_generator=schedule_generator) + sparse_env_from_file = RailEnv(width=1, height=1, rail_generator=rail_generator, + schedule_generator=schedule_generator) sparse_env_from_file.reset(True, True) # Assert loaded agent number is correct @@ -84,7 +67,8 @@ def test_schedule_from_file(): # Random generator rail_generator = rail_from_file("./random_env_test.pkl") schedule_generator = schedule_from_file("./random_env_test.pkl") - random_env_from_file = RailEnv(width=1, height=1, rail_generator=rail_generator, schedule_generator=schedule_generator) + random_env_from_file = RailEnv(width=1, height=1, rail_generator=rail_generator, + schedule_generator=schedule_generator) random_env_from_file.reset(True, True) # Assert loaded agent number is correct @@ -96,7 +80,8 @@ def test_schedule_from_file(): # Complex generator rail_generator = rail_from_file("./complex_env_test.pkl") schedule_generator = schedule_from_file("./complex_env_test.pkl") - complex_env_from_file = RailEnv(width=1, height=1, rail_generator=rail_generator, schedule_generator=schedule_generator) + complex_env_from_file = RailEnv(width=1, height=1, rail_generator=rail_generator, + schedule_generator=schedule_generator) complex_env_from_file.reset(True, True) # Assert loaded agent number is correct @@ -104,4 +89,3 @@ def test_schedule_from_file(): # Assert max steps is correct assert complex_env_from_file._max_episode_steps == 1350 -