diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py index ef28266d6bf39446b8fb2c1f2e179e626e9eae3f..b7e1d56a529d5d05aa8c56e92afb439f1d73a401 100644 --- a/tests/test_flatland_envs_rail_env_shortest_paths.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -6,7 +6,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env_shortest_paths import get_shortest_paths, get_k_shortest_paths -#from flatland.envs.rail_env_utils import load_flatland_environment_from_file +from flatland.envs.rail_env_utils import load_flatland_environment_from_file from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.rail_trainrun_data_structures import Waypoint from flatland.envs.schedule_generators import random_schedule_generator @@ -42,10 +42,32 @@ def test_get_shortest_paths_unreachable(): # todo file test_002.pkl has to be generated automatically # see https://gitlab.aicrowd.com/flatland/flatland/issues/279 def test_get_shortest_paths(): - #env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') - env, env_dict = RailEnvPersister.load_new("./env_data/tests/test_002.mpk") + #env = load_flatland_environment_from_file('test_002.mpk', 'env_data.tests') + env, env_dict = RailEnvPersister.load_new("test_002.mpk", "env_data.tests") + + #print("env len(agents): ", len(env.agents)) + #print(env.distance_map) + #print("env number_of_agents:", env.number_of_agents) + + #print("env agents:", env.agents) + + #env.distance_map.reset(env.agents, env.rail) + + #actual = get_shortest_paths(env.distance_map) + #print("shortest paths:", actual) + + #print(env.distance_map) + #print("Dist map agents:", env.distance_map.agents) + + #print("\nenv reset()") env.reset() actual = get_shortest_paths(env.distance_map) + #print("env agents: ", len(env.agents)) + #print("env number_of_agents: ", env.number_of_agents) + + + + assert len(actual) == 2, "get_shortest_paths should return a dict of length 2" expected = { 0: [ @@ -99,7 +121,7 @@ def test_get_shortest_paths(): # see https://gitlab.aicrowd.com/flatland/flatland/issues/279 def test_get_shortest_paths_max_depth(): #env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') - env, _ = RailEnvPersister.load_new("./env_data/tests/test_002.mpk") + env, _ = RailEnvPersister.load_new("test_002.mpk", "env_data.tests") env.reset() actual = get_shortest_paths(env.distance_map, max_depth=2) @@ -288,3 +310,9 @@ def test_get_k_shortest_paths(rendering=False): ]) assert actual == expected, "actual={},expected={}".format(actual, expected) + +def main(): + test_get_shortest_paths() + +if __name__ == "__main__": + main()