diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py
index 02d1a3af83bbee4a65f66259dbfb6c7c18b25dba..b23347da08d5ddb12ad700ce8b38ce75e2a1eefa 100644
--- a/flatland/envs/rail_env_utils.py
+++ b/flatland/envs/rail_env_utils.py
@@ -98,7 +98,6 @@ def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[ShortestPath
         distance = math.inf
         while (position != a.target):
             next_actions = get_valid_move_actions_(direction, position, distance_map.rail)
-
             best_next_action = None
             for next_action in next_actions:
                 next_action_distance = distance_map.get()[
@@ -111,5 +110,9 @@ def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[ShortestPath
 
             position = best_next_action.next_position
             direction = best_next_action.next_direction
+        if position == a.target:
+            shortest_paths[a.handle].append(
+                ShortestPathElement(position, direction,
+                                    RailEnvNextAction(RailEnvActions.DO_NOTHING, position, direction)))
 
-    return shortest_paths
+        return shortest_paths
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index 2df60017bc5add7e90d1bbd3b07fd24299dbcc11..d891ccaac76b6810ab55732c836ab8a4bdc7ba03 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -8,7 +8,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction
-from flatland.envs.rail_env_utils import get_shortest_paths
+from flatland.envs.rail_env_utils import get_shortest_paths, ShortestPathElement
 from flatland.envs.rail_generators import rail_from_grid_transition_map
 from flatland.envs.schedule_generators import random_schedule_generator
 from flatland.utils.rendertools import RenderTool
@@ -145,11 +145,18 @@ def test_shortest_path_predictor(rendering=False):
 
     paths = get_shortest_paths(env.distance_map)[0]
     assert paths == [
-        RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6), next_direction=0),
-        RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6), next_direction=0),
-        RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 7), next_direction=1),
-        RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 8), next_direction=1),
-        RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 9), next_direction=1)]
+        ShortestPathElement((5, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6),
+                                                         next_direction=0)),
+        ShortestPathElement((4, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6),
+                                                         next_direction=0)),
+        ShortestPathElement((3, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 7),
+                                                         next_direction=1)),
+        ShortestPathElement((3, 7), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 8),
+                                                         next_direction=1)),
+        ShortestPathElement((3, 8), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 9),
+                                                         next_direction=1)),
+        ShortestPathElement((3, 9), 1, RailEnvNextAction(action=RailEnvActions.DO_NOTHING, next_position=(3, 9),
+                                                         next_direction=1))]
 
     # extract the data
     predictions = env.obs_builder.predictions