diff --git a/flatland/envs/rail_env_shortest_paths.py b/flatland/envs/rail_env_shortest_paths.py index e19464b77cf09980616a88384542858625d138bb..b325b920155ebf6433963de8f2c71a45af92547d 100644 --- a/flatland/envs/rail_env_shortest_paths.py +++ b/flatland/envs/rail_env_shortest_paths.py @@ -372,13 +372,16 @@ def get_k_shortest_paths(env: RailEnv, direction=el.direction, next_action_element=RailEnvNextAction( action=int(get_action_for_move(el.position, - el.direction, - p[i + 1].position, - p[i + 1].direction, - env.rail)), + el.direction, + p[i + 1].position, + p[i + 1].direction, + env.rail)), next_position=p[i + 1].position, next_direction=p[i + 1].direction)) for i, el in enumerate(p[:-1])) - shortest_paths_with_action.append(p_with_action) + target_walking_element = WalkingElement(position=p[-1].position, direction=p[-1].direction, + next_action_element=RailEnvNextAction(action=int(RailEnvActions.DO_NOTHING), + next_position=None, next_direction=None)) + shortest_paths_with_action.append(p_with_action + (target_walking_element,)) # return P return shortest_paths_with_action diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py index d69e203e94e4ef8a7b68c0c7fe44a4c1f9fde42b..bfcfb6cbfc0a3a6042c70aa72d11101bcc283df3 100644 --- a/tests/test_flatland_envs_rail_env_shortest_paths.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -436,7 +436,9 @@ def test_get_k_shortest_paths(rendering=False): WalkingElement(position=(1, 9), direction=2, next_action_element=RailEnvNextAction(action=2, next_position=(2, 9), next_direction=2)), WalkingElement(position=(2, 9), direction=2, - next_action_element=RailEnvNextAction(action=2, next_position=(3, 9), next_direction=2)) + next_action_element=RailEnvNextAction(action=2, next_position=(3, 9), next_direction=2)), + WalkingElement(position=(3, 9), direction=2, + next_action_element=RailEnvNextAction(action=0, next_position=None, next_direction=None)) ), ( WalkingElement(position=(3, 1), direction=3, @@ -470,7 +472,9 @@ def test_get_k_shortest_paths(rendering=False): WalkingElement(position=(4, 8), direction=1, next_action_element=RailEnvNextAction(action=2, next_position=(4, 9), next_direction=1)), WalkingElement(position=(4, 9), direction=1, - next_action_element=RailEnvNextAction(action=2, next_position=(3, 9), next_direction=0)) + next_action_element=RailEnvNextAction(action=2, next_position=(3, 9), next_direction=0)), + WalkingElement(position=(3, 9), direction=0, + next_action_element=RailEnvNextAction(action=0, next_position=None, next_direction=None)) ) ])