diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 00dabd312bb6bc21b7b7fbdd84608df2d8ee1ec7..b90d38a427533846baadb704f5137c90c1044f73 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -109,7 +109,11 @@ class EnvAgent:
         return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle]

     def get_travel_time_on_shortest_path(self, distance_map) -> int:
-        distance = len(self.get_shortest_path(distance_map))
+        shortest_path = self.get_shortest_path(distance_map)
+        if shortest_path is not None:
+            distance = len(shortest_path)
+        else:
+            distance = 0
         speed = self.speed_data['speed']
         return int(np.ceil(distance / speed))

diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py
index bc4b169b1aad3893d96d82cb8284b369f13104f2..b8354078f68705e2d3757a61626c03df2ac9dc4c 100644
--- a/flatland/envs/persistence.py
+++ b/flatland/envs/persistence.py
@@ -21,7 +21,7 @@ from flatland.envs.distance_map import DistanceMap

 # cannot import objects / classes directly because of circular import
 from flatland.envs import malfunction_generators as mal_gen
 from flatland.envs import rail_generators as rail_gen
-from flatland.envs import schedule_generators as sched_gen
+from flatland.envs import line_generators as line_gen

 msgpack_numpy.patch()
@@ -122,7 +122,7 @@ class RailEnvPersister(object):
             width=width,
             height=height,
             rail_generator=rail_gen.rail_from_file(filename, load_from_package=load_from_package),
-            schedule_generator=sched_gen.schedule_from_file(filename,
+            line_generator=line_gen.line_from_file(filename,
                 load_from_package=load_from_package),
             #malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename,
             #    load_from_package=load_from_package),
diff --git a/tests/test_action_plan.py b/tests/test_action_plan.py
index 815ecbcd7c1e091681b67d5c095b890ae6ab798a..d5c95408c21bf4f427dbbae929e36cc40ee79065 100644
--- a/tests/test_action_plan.py
+++ b/tests/test_action_plan.py
@@ -34,25 +34,29 @@ def test_action_plan(rendering: bool = False):
     for handle, agent in enumerate(env.agents):
         print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target))

-    chosen_path_dict = {0: [TrainrunWaypoint(lined_at=0, waypoint=Waypoint(position=(3, 0), direction=3)),
-                            TrainrunWaypoint(lined_at=2, waypoint=Waypoint(position=(3, 1), direction=1)),
-                            TrainrunWaypoint(lined_at=3, waypoint=Waypoint(position=(3, 2), direction=1)),
-                            TrainrunWaypoint(lined_at=14, waypoint=Waypoint(position=(3, 3), direction=1)),
-                            TrainrunWaypoint(lined_at=15, waypoint=Waypoint(position=(3, 4), direction=1)),
-                            TrainrunWaypoint(lined_at=16, waypoint=Waypoint(position=(3, 5), direction=1)),
-                            TrainrunWaypoint(lined_at=17, waypoint=Waypoint(position=(3, 6), direction=1)),
-                            TrainrunWaypoint(lined_at=18, waypoint=Waypoint(position=(3, 7), direction=1)),
-                            TrainrunWaypoint(lined_at=19, waypoint=Waypoint(position=(3, 8), direction=1)),
-                            TrainrunWaypoint(lined_at=20, waypoint=Waypoint(position=(3, 8), direction=5))],
-                        1: [TrainrunWaypoint(lined_at=0, waypoint=Waypoint(position=(3, 8), direction=3)),
-                            TrainrunWaypoint(lined_at=3, waypoint=Waypoint(position=(3, 7), direction=3)),
-                            TrainrunWaypoint(lined_at=5, waypoint=Waypoint(position=(3, 6), direction=3)),
-                            TrainrunWaypoint(lined_at=7, waypoint=Waypoint(position=(3, 5), direction=3)),
-                            TrainrunWaypoint(lined_at=9, waypoint=Waypoint(position=(3, 4), direction=3)),
-                            TrainrunWaypoint(lined_at=11, waypoint=Waypoint(position=(3, 3), direction=3)),
-                            TrainrunWaypoint(lined_at=13, waypoint=Waypoint(position=(2, 3), direction=0)),
-                            TrainrunWaypoint(lined_at=15, waypoint=Waypoint(position=(1, 3), direction=0)),
-                            TrainrunWaypoint(lined_at=17, waypoint=Waypoint(position=(0, 3), direction=0))]}
+    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
+    for _ in range(max([agent.earliest_departure for agent in env.agents])):
+        env.step({})  # DO_NOTHING for all agents
+
+    chosen_path_dict = {0: [TrainrunWaypoint(scheduled_at=0, waypoint=Waypoint(position=(3, 0), direction=3)),
+                            TrainrunWaypoint(scheduled_at=2, waypoint=Waypoint(position=(3, 1), direction=1)),
+                            TrainrunWaypoint(scheduled_at=3, waypoint=Waypoint(position=(3, 2), direction=1)),
+                            TrainrunWaypoint(scheduled_at=14, waypoint=Waypoint(position=(3, 3), direction=1)),
+                            TrainrunWaypoint(scheduled_at=15, waypoint=Waypoint(position=(3, 4), direction=1)),
+                            TrainrunWaypoint(scheduled_at=16, waypoint=Waypoint(position=(3, 5), direction=1)),
+                            TrainrunWaypoint(scheduled_at=17, waypoint=Waypoint(position=(3, 6), direction=1)),
+                            TrainrunWaypoint(scheduled_at=18, waypoint=Waypoint(position=(3, 7), direction=1)),
+                            TrainrunWaypoint(scheduled_at=19, waypoint=Waypoint(position=(3, 8), direction=1)),
+                            TrainrunWaypoint(scheduled_at=20, waypoint=Waypoint(position=(3, 8), direction=5))],
+                        1: [TrainrunWaypoint(scheduled_at=0, waypoint=Waypoint(position=(3, 8), direction=3)),
+                            TrainrunWaypoint(scheduled_at=3, waypoint=Waypoint(position=(3, 7), direction=3)),
+                            TrainrunWaypoint(scheduled_at=5, waypoint=Waypoint(position=(3, 6), direction=3)),
+                            TrainrunWaypoint(scheduled_at=7, waypoint=Waypoint(position=(3, 5), direction=3)),
+                            TrainrunWaypoint(scheduled_at=9, waypoint=Waypoint(position=(3, 4), direction=3)),
+                            TrainrunWaypoint(scheduled_at=11, waypoint=Waypoint(position=(3, 3), direction=3)),
+                            TrainrunWaypoint(scheduled_at=13, waypoint=Waypoint(position=(2, 3), direction=0)),
+                            TrainrunWaypoint(scheduled_at=15, waypoint=Waypoint(position=(1, 3), direction=0)),
+                            TrainrunWaypoint(scheduled_at=17, waypoint=Waypoint(position=(0, 3), direction=0))]}
     expected_action_plan = [[
         # take action to enter the grid
         ActionPlanElement(0, RailEnvActions.MOVE_FORWARD),
diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py
index 4f507415e8acc25a1d931b68f249d568839fd609..87b373051c30e90fb3665a78d8d8e5a46a468d91 100644
--- a/tests/test_flaltland_rail_agent_status.py
+++ b/tests/test_flaltland_rail_agent_status.py
@@ -17,6 +17,11 @@ def test_initial_status():
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   remove_agents_at_target=False)
     env.reset()
+
+    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
+    for _ in range(max([agent.earliest_departure for agent in env.agents])):
+        env.step({})  # DO_NOTHING for all agents
+
     set_penalties_for_replay(env)
     test_config = ReplayConfig(
         replay=[
@@ -126,6 +131,10 @@ def test_status_done_remove():
                   remove_agents_at_target=True)
     env.reset()

+    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
+    for _ in range(max([agent.earliest_departure for agent in env.agents])):
+        env.step({})  # DO_NOTHING for all agents
+
     set_penalties_for_replay(env)
     test_config = ReplayConfig(
         replay=[
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index c531359a2d3d573db32cdec7515155511e5f6d57..53a61c8f755f813d10f892f1a60f17d8650a89e7 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -9,9 +9,9 @@ from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
-from flatland.envs.rail_generators import complex_rail_generator, rail_from_file
+from flatland.envs.rail_generators import sparse_rail_generator, rail_from_file
 from flatland.envs.rail_generators import rail_from_grid_transition_map
-from flatland.envs.line_generators import random_line_generator, complex_line_generator, line_from_file
+from flatland.envs.line_generators import random_line_generator, sparse_line_generator, line_from_file
 from flatland.utils.simple_rail import make_simple_rail
 from flatland.envs.persistence import RailEnvPersister
 from flatland.utils.rendertools import RenderTool
@@ -37,9 +37,10 @@ def test_load_env():

 def test_save_load():
     env = RailEnv(width=10, height=10,
-                  rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
-                  line_generator=complex_line_generator(), number_of_agents=2)
+                  rail_generator=sparse_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
+                  line_generator=sparse_line_generator(), number_of_agents=2)
     env.reset()
+
     agent_1_pos = env.agents[0].position
     agent_1_dir = env.agents[0].direction
     agent_1_tar = env.agents[0].target
diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py
index b2ee9b015c74f74c651aa6396dee3c6ea5a49ce4..ce88aeb81d06875768e903eb63582b2a79160e91 100644
--- a/tests/test_flatland_envs_rail_env_shortest_paths.py
+++ b/tests/test_flatland_envs_rail_env_shortest_paths.py
@@ -23,6 +23,10 @@ def test_get_shortest_paths_unreachable():
                   obs_builder_object=GlobalObsForRailEnv())
     env.reset()

+    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
+    for _ in range(max([agent.earliest_departure for agent in env.agents])):
+        env.step({})  # DO_NOTHING for all agents
+
     # set the initial position
     agent = env.agents[0]
     agent.position = (3, 1)  # west dead-end
@@ -36,7 +40,7 @@ def test_get_shortest_paths_unreachable():
     actual = get_shortest_paths(env.distance_map)
     expected = {0: None}

-    assert actual == expected, "actual={},expected={}".format(actual, expected)
+    assert actual[0] == expected[0], "actual={},expected={}".format(actual[0], expected[0])


 # todo file test_002.pkl has to be generated automatically
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index 3e74d720d4276dae588a3989843ada9d1761671d..74e71daced5cde123f7b25054b264ebeee816888 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -1512,7 +1512,7 @@ def test_sparse_generator_changes_to_grid_mode():
     rail_env = RailEnv(width=10, height=20, rail_generator=sparse_rail_generator(
         max_num_cities=100,
         max_rails_between_cities=2,
-        max_rails_in_city=2,
+        max_rail_pairs_in_city=1,
         seed=15,
         grid_mode=False
     ), line_generator=sparse_line_generator(), number_of_agents=10,
diff --git a/tests/test_flatland_schedule_from_file.py b/tests/test_flatland_line_from_file.py
similarity index 79%
rename from tests/test_flatland_schedule_from_file.py
rename to tests/test_flatland_line_from_file.py
index 0b903eae8b56864afc60d0d0d01923e337796479..b324af98a8550a47621be6258d2d559f128f45a5 100644
--- a/tests/test_flatland_schedule_from_file.py
+++ b/tests/test_flatland_line_from_file.py
@@ -25,12 +25,14 @@ def test_line_from_file_sparse():
                                            seed=1,
                                            grid_mode=False,
                                            max_rails_between_cities=3,
-                                           max_rails_in_city=6,
+                                           max_rail_pairs_in_city=3,
                                            )

     line_generator = sparse_line_generator(speed_ration_map)

-    create_and_save_env(file_name="./sparse_env_test.pkl", rail_generator=rail_generator,
+    env = create_and_save_env(file_name="./sparse_env_test.pkl", rail_generator=rail_generator,
                         line_generator=line_generator)
+    old_num_steps = env._max_episode_steps
+    old_num_agents = len(env.agents)

     # Sparse generator
@@ -41,10 +43,10 @@
     sparse_env_from_file.reset(True, True)

     # Assert loaded agent number is correct
-    assert sparse_env_from_file.get_num_agents() == 10
+    assert sparse_env_from_file.get_num_agents() == old_num_agents

     # Assert max steps is correct
-    assert sparse_env_from_file._max_episode_steps == 500
+    assert sparse_env_from_file._max_episode_steps == old_num_steps



@@ -65,8 +67,10 @@ def test_line_from_file_random():
     rail_generator = random_rail_generator()

     line_generator = random_line_generator(speed_ration_map)

-    create_and_save_env(file_name="./random_env_test.pkl", rail_generator=rail_generator,
+    env = create_and_save_env(file_name="./random_env_test.pkl", rail_generator=rail_generator,
                         line_generator=line_generator)
+    old_num_steps = env._max_episode_steps
+    old_num_agents = len(env.agents)

     # Random generator
@@ -77,10 +81,10 @@
     random_env_from_file.reset(True, True)

     # Assert loaded agent number is correct
-    assert random_env_from_file.get_num_agents() == 10
+    assert random_env_from_file.get_num_agents() == old_num_agents

     # Assert max steps is correct
-    assert random_env_from_file._max_episode_steps == 1350
+    assert random_env_from_file._max_episode_steps == old_num_steps



@@ -105,8 +109,10 @@
                                             max_dist=99999)

     line_generator = complex_line_generator(speed_ration_map)

-    create_and_save_env(file_name="./complex_env_test.pkl", rail_generator=rail_generator,
+    env = create_and_save_env(file_name="./complex_env_test.pkl", rail_generator=rail_generator,
                         line_generator=line_generator)
+    old_num_steps = env._max_episode_steps
+    old_num_agents = len(env.agents)

     # Load the different envs and check the parameters
@@ -119,7 +125,7 @@
     complex_env_from_file.reset(True, True)

     # Assert loaded agent number is correct
-    assert complex_env_from_file.get_num_agents() == 10
+    assert complex_env_from_file.get_num_agents() == old_num_agents

     # Assert max steps is correct
-    assert complex_env_from_file._max_episode_steps == 1350
+    assert complex_env_from_file._max_episode_steps == old_num_steps
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 53915102c539c8e7703434100634c36fda001049..8675f54fb68bc4203a0a5e3fd9799791fa4ca539 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -97,6 +97,8 @@ def test_malfunction_process():
             actions[i] = np.argmax(obs[i]) + 1

         obs, all_rewards, done, _ = env.step(actions)
+        if done["__all__"]:
+            break

         if env.agents[0].malfunction_data['malfunction'] > 0:
             agent_malfunctioning = True
diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py
index 5b090681813b9976285fe5a726ff9e7dc0226b08..851d849d1246773d7d06b5f38ed0eef820f74a56 100644
--- a/tests/test_global_observation.py
+++ b/tests/test_global_observation.py
@@ -30,6 +30,10 @@ def test_get_global_observation():
                   obs_builder_object=GlobalObsForRailEnv())
     env.reset()

+    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
+    for _ in range(max([agent.earliest_departure for agent in env.agents])):
+        env.step({})  # DO_NOTHING for all agents
+
     obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)})
     for i in range(len(env.agents)):
         agent: EnvAgent = env.agents[i]
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 08b46d00f4a458cfd105d568ca2a045f122cc8bc..2664c5b4f8b18004b4a39095c16acee753832eb7 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -51,6 +51,7 @@ def test_multi_speed_init():
                   rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
                                                         seed=1), line_generator=complex_line_generator(),
                   number_of_agents=5)
+
     # Initialize the agent with the parameters corresponding to the environment and observation_builder
     agent = RandomAgent(218, 4)

@@ -197,6 +198,12 @@ def test_multispeed_actions_no_malfunction_blocking():
                   line_generator=random_line_generator(), number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
+
+    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
+    for _ in range(max([agent.earliest_departure for agent in env.agents])):
+        env.step({})  # DO_NOTHING for all agents
+
+
     set_penalties_for_replay(env)
     test_configs = [
         ReplayConfig(
@@ -381,7 +388,11 @@ def test_multispeed_actions_malfunction_no_blocking():
                   line_generator=random_line_generator(), number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
-
+
+    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
+    for _ in range(max([agent.earliest_departure for agent in env.agents])):
+        env.step({})  # DO_NOTHING for all agents
+
     set_penalties_for_replay(env)
     test_config = ReplayConfig(
         replay=[
@@ -515,6 +526,10 @@ def test_multispeed_actions_no_malfunction_invalid_actions():
                   line_generator=random_line_generator(), number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
+
+    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
+    for _ in range(max([agent.earliest_departure for agent in env.agents])):
+        env.step({})  # DO_NOTHING for all agents

     set_penalties_for_replay(env)
     test_config = ReplayConfig(
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 4e7c30ca72811b78c564d0c11a011dd4b9e04998..062d56f00dd704960b316e318ee311f5c7a03539 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -152,4 +152,4 @@ def create_and_save_env(file_name: str, line_generator: LineGenerator, rail_gene
     env.reset(True, True)
     #env.save(file_name)
     RailEnvPersister.save(env, file_name)
-
+    return env
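
Note: the warm-up loop added across the test files above is the same three-line idiom each time. Below is a minimal sketch that factors it into a helper, assuming only the attributes already used in the hunks (env.agents, agent.earliest_departure, env.step); the helper name warm_up_until_departure is hypothetical and is not part of the flatland-rl API.

def warm_up_until_departure(env):
    # Hypothetical helper (not part of flatland-rl): step the environment with
    # an empty action dict (DO_NOTHING for every agent) until the latest
    # earliest_departure has elapsed, so all trains are at least
    # READY_TO_DEPART before a test starts replaying explicit actions.
    for _ in range(max(agent.earliest_departure for agent in env.agents)):
        env.step({})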