diff --git a/flatland/envs/rail_env_shortest_paths.py b/flatland/envs/rail_env_shortest_paths.py
index c945ad7b779e1108945161920cb819f4e10f25e9..e19464b77cf09980616a88384542858625d138bb 100644
--- a/flatland/envs/rail_env_shortest_paths.py
+++ b/flatland/envs/rail_env_shortest_paths.py
@@ -71,6 +71,128 @@ def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
     return valid_actions
 
 
+def get_new_position_for_action(
+    agent_position: Tuple[int, int],
+    agent_direction: Grid4TransitionsEnum,
+    action: RailEnvActions,
+    rail: GridTransitionMap) -> Optional[Tuple[Tuple[int, int], int]]:
+    """
+    Get the cell and direction an agent ends up in when it takes `action`.
+
+    The admissible actions depend on the transitions the rail offers for the
+    agent's current cell and direction:
+
+    - dead end: only MOVE_FORWARD is accepted; it turns the agent around.
+    - exactly one transition: only MOVE_FORWARD is accepted; it follows the
+      single open branch among left, forward and right.
+    - otherwise: MOVE_LEFT, MOVE_FORWARD and MOVE_RIGHT each select the branch
+      of the same name, provided that branch is open.
+
+    Parameters
+    ----------
+    agent_position
+        current cell of the agent
+    agent_direction
+        current heading of the agent; other headings are derived modulo 4
+    action
+        the action to evaluate, compared against RailEnvActions members
+    rail
+        transition map queried for the open transitions and for dead ends
+
+    Returns
+    -------
+    Optional[Tuple[Tuple[int, int], int]]
+        (new_position, new_direction) if the action is possible, where
+        new_position is whatever get_new_position yields (presumably a
+        (row, column) tuple); None if the action cannot be taken here.
+    """
+    possible_transitions = rail.get_transitions(*agent_position, agent_direction)
+    num_transitions = np.count_nonzero(possible_transitions)
+
+    # Candidate branches as (direction offset relative to the heading, action
+    # that selects the branch), in the order left, forward, right.
+    if rail.is_dead_end(agent_position):
+        # the only branch considered in a dead end is the turnaround
+        candidates = [(2, RailEnvActions.MOVE_FORWARD)]
+    elif num_transitions == 1:
+        # the forward action is aligned with the single open branch
+        candidates = [(offset, RailEnvActions.MOVE_FORWARD) for offset in (-1, 0, 1)]
+    else:
+        candidates = [
+            (-1, RailEnvActions.MOVE_LEFT),
+            (0, RailEnvActions.MOVE_FORWARD),
+            (1, RailEnvActions.MOVE_RIGHT),
+        ]
+
+    for offset, valid_action in candidates:
+        new_direction = (agent_direction + offset) % 4
+        if possible_transitions[new_direction] and valid_action == action:
+            return get_new_position(agent_position, new_direction), new_direction
+    return None
+
+
+def get_action_for_move(
+    agent_position: Tuple[int, int],
+    agent_direction: Grid4TransitionsEnum,
+    next_agent_position: Tuple[int, int],
+    next_agent_direction: int,
+    rail: GridTransitionMap) -> Optional[RailEnvActions]:
+    """
+    Find the action that takes an agent from one pose directly to another.
+
+    A pose is a cell together with a direction. The branches leaving the
+    current pose are tried in the order left, forward, right; the first open
+    branch whose resulting pose equals the requested one determines the
+    action:
+
+    - dead end: the only branch is the turnaround, labelled MOVE_FORWARD.
+    - exactly one transition: every branch is labelled MOVE_FORWARD.
+    - otherwise: branches are labelled MOVE_LEFT, MOVE_FORWARD, MOVE_RIGHT.
+
+    Parameters
+    ----------
+    agent_position
+        cell the agent starts from
+    agent_direction
+        direction the agent starts with
+    next_agent_position
+        cell the agent should reach in a single move
+    next_agent_direction
+        direction the agent should have after that move
+    rail
+        transition map queried for the open transitions and for dead ends
+
+    Returns
+    -------
+    Optional[RailEnvActions]
+        the action (if direct transition possible) or None.
+    """
+    open_branches = rail.get_transitions(*agent_position, agent_direction)
+    branch_count = np.count_nonzero(open_branches)
+    forward = RailEnvActions.MOVE_FORWARD
+
+    # Each option pairs a turn (offset added to the current direction, modulo 4)
+    # with the action that triggers it.
+    if rail.is_dead_end(agent_position):
+        options = ((2, forward),)
+    elif branch_count == 1:
+        # with a single way on, forward is aligned with whichever branch is open
+        options = tuple((turn, forward) for turn in range(-1, 2))
+    else:
+        options = ((-1, RailEnvActions.MOVE_LEFT), (0, forward), (1, RailEnvActions.MOVE_RIGHT))
+
+    for turn, move in options:
+        heading = (agent_direction + turn) % 4
+        if not open_branches[heading]:
+            continue
+        reached = get_new_position(agent_position, heading)
+        if reached == next_agent_position and heading == next_agent_direction:
+            return move
+
+    # no open branch leads directly to the requested pose
+    return None
+
+
 # N.B. get_shortest_paths is not part of distance_map since it refers to RailEnvActions (would lead to circularity!)
 def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = None, agent_handle: Optional[int] = None) \
     -> Dict[int, Optional[List[WalkingElement]]]:
@@ -157,6 +279,7 @@ def get_k_shortest_paths(env: RailEnv,
     Computes the k shortest paths using modified Dijkstra
     following pseudo-code https://en.wikipedia.org/wiki/K_shortest_path_routing
     In contrast to the pseudo-code in wikipedia, we do not a allow for loopy paths.
+    We add the next_action_element
 
     Parameters
     ----------
@@ -173,7 +296,6 @@ def get_k_shortest_paths(env: RailEnv,
     -------
     List[Tuple[WalkingElement]]
         We use tuples since we need the path elements to be hashable.
-        The walking elements do not contain any actions.
         We use a list of paths in order to keep the order of length.
     """
 
@@ -215,7 +337,8 @@ def get_k_shortest_paths(env: RailEnv,
 
         # – if u = t then P = P U {Pu}
         if u.position == target_position:
-            print(" found of length {} {}".format(len(pu), pu))
+            if debug:
+                print(" found of length {} {}".format(len(pu), pu))
             shortest_paths.append(pu)
 
         # – if countu ≤ K then
@@ -242,9 +365,23 @@ def get_k_shortest_paths(env: RailEnv,
                     pv = pu + (v,)
                     #     – insert Pv into B
                     heap.add(pv)
+    # add actions to shortest paths
+    shortest_paths_with_action = []
+    for p in shortest_paths:
+        p_with_action = tuple(WalkingElement(position=el.position,
+                                             direction=el.direction,
+                                             next_action_element=RailEnvNextAction(
+                                                 action=int(get_action_for_move(el.position,
+                                                                            el.direction,
+                                                                            p[i + 1].position,
+                                                                            p[i + 1].direction,
+                                                                            env.rail)),
+                                                 next_position=p[i + 1].position,
+                                                 next_direction=p[i + 1].direction)) for i, el in enumerate(p[:-1]))
+        shortest_paths_with_action.append(p_with_action)
 
     # return P
-    return shortest_paths
+    return shortest_paths_with_action
 
 
 def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0):
diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py
index 395f05e885cdb80198df47d2d27e241099a0e4a8..d69e203e94e4ef8a7b68c0c7fe44a4c1f9fde42b 100644
--- a/tests/test_flatland_envs_rail_env_shortest_paths.py
+++ b/tests/test_flatland_envs_rail_env_shortest_paths.py
@@ -395,49 +395,83 @@ def test_get_k_shortest_paths(rendering=False):
         renderer.render_env(show=True, show_observations=False)
         input()
 
-    actual = get_k_shortest_paths(
+    actual = set(get_k_shortest_paths(
         env=env,
         source_position=initial_position,  # west dead-end
         source_direction=int(initial_direction),  # east
         target_position=target_position,
         k=10
-    )
+    ))
 
-    expected = [
-        (WalkingElement(position=(3, 1), direction=3, next_action_element=None),
-         WalkingElement(position=(3, 0), direction=3, next_action_element=None),
-         WalkingElement(position=(3, 1), direction=1, next_action_element=None),
-         WalkingElement(position=(3, 2), direction=1, next_action_element=None),
-         WalkingElement(position=(3, 3), direction=1, next_action_element=None),
-         WalkingElement(position=(3, 4), direction=1, next_action_element=None),
-         WalkingElement(position=(3, 5), direction=1, next_action_element=None),
-         WalkingElement(position=(3, 6), direction=1, next_action_element=None),
-         WalkingElement(position=(4, 6), direction=2, next_action_element=None),
-         WalkingElement(position=(5, 6), direction=2, next_action_element=None),
-         WalkingElement(position=(6, 6), direction=2, next_action_element=None),
-         WalkingElement(position=(5, 6), direction=0, next_action_element=None),
-         WalkingElement(position=(4, 6), direction=0, next_action_element=None),
-         WalkingElement(position=(4, 7), direction=1, next_action_element=None),
-         WalkingElement(position=(4, 8), direction=1, next_action_element=None),
-         WalkingElement(position=(4, 9), direction=1, next_action_element=None),
-         WalkingElement(position=(3, 9), direction=0, next_action_element=None)),
-        (WalkingElement(position=(3, 1), direction=3, next_action_element=None),
-         WalkingElement(position=(3, 0), direction=3, next_action_element=None),
-         WalkingElement(position=(3, 1), direction=1, next_action_element=None),
-         WalkingElement(position=(3, 2), direction=1, next_action_element=None),
-         WalkingElement(position=(3, 3), direction=1, next_action_element=None),
-         WalkingElement(position=(2, 3), direction=0, next_action_element=None),
-         WalkingElement(position=(1, 3), direction=0, next_action_element=None),
-         WalkingElement(position=(0, 3), direction=0, next_action_element=None),
-         WalkingElement(position=(0, 4), direction=1, next_action_element=None),
-         WalkingElement(position=(0, 5), direction=1, next_action_element=None),
-         WalkingElement(position=(0, 6), direction=1, next_action_element=None),
-         WalkingElement(position=(0, 7), direction=1, next_action_element=None),
-         WalkingElement(position=(0, 8), direction=1, next_action_element=None),
-         WalkingElement(position=(0, 9), direction=1, next_action_element=None),
-         WalkingElement(position=(1, 9), direction=2, next_action_element=None),
-         WalkingElement(position=(2, 9), direction=2, next_action_element=None),
-         WalkingElement(position=(3, 9), direction=2, next_action_element=None))
-    ]
+    expected = set([
+        (
+            WalkingElement(position=(3, 1), direction=3,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(3, 0), next_direction=3)),
+            WalkingElement(position=(3, 0), direction=3,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(3, 1), next_direction=1)),
+            WalkingElement(position=(3, 1), direction=1,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(3, 2), next_direction=1)),
+            WalkingElement(position=(3, 2), direction=1,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(3, 3), next_direction=1)),
+            WalkingElement(position=(3, 3), direction=1,
+                           next_action_element=RailEnvNextAction(action=1, next_position=(2, 3), next_direction=0)),
+            WalkingElement(position=(2, 3), direction=0,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(1, 3), next_direction=0)),
+            WalkingElement(position=(1, 3), direction=0,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(0, 3), next_direction=0)),
+            WalkingElement(position=(0, 3), direction=0,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(0, 4), next_direction=1)),
+            WalkingElement(position=(0, 4), direction=1,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(0, 5), next_direction=1)),
+            WalkingElement(position=(0, 5), direction=1,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(0, 6), next_direction=1)),
+            WalkingElement(position=(0, 6), direction=1,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(0, 7), next_direction=1)),
+            WalkingElement(position=(0, 7), direction=1,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(0, 8), next_direction=1)),
+            WalkingElement(position=(0, 8), direction=1,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(0, 9), next_direction=1)),
+            WalkingElement(position=(0, 9), direction=1,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(1, 9), next_direction=2)),
+            WalkingElement(position=(1, 9), direction=2,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(2, 9), next_direction=2)),
+            WalkingElement(position=(2, 9), direction=2,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(3, 9), next_direction=2))
+        ),
+        (
+            WalkingElement(position=(3, 1), direction=3,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(3, 0), next_direction=3)),
+            WalkingElement(position=(3, 0), direction=3,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(3, 1), next_direction=1)),
+            WalkingElement(position=(3, 1), direction=1,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(3, 2), next_direction=1)),
+            WalkingElement(position=(3, 2), direction=1,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(3, 3), next_direction=1)),
+            WalkingElement(position=(3, 3), direction=1,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(3, 4), next_direction=1)),
+            WalkingElement(position=(3, 4), direction=1,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(3, 5), next_direction=1)),
+            WalkingElement(position=(3, 5), direction=1,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(3, 6), next_direction=1)),
+            WalkingElement(position=(3, 6), direction=1,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(4, 6), next_direction=2)),
+            WalkingElement(position=(4, 6), direction=2,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(5, 6), next_direction=2)),
+            WalkingElement(position=(5, 6), direction=2,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(6, 6), next_direction=2)),
+            WalkingElement(position=(6, 6), direction=2,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(5, 6), next_direction=0)),
+            WalkingElement(position=(5, 6), direction=0,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(4, 6), next_direction=0)),
+            WalkingElement(position=(4, 6), direction=0,
+                           next_action_element=RailEnvNextAction(action=3, next_position=(4, 7), next_direction=1)),
+            WalkingElement(position=(4, 7), direction=1,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(4, 8), next_direction=1)),
+            WalkingElement(position=(4, 8), direction=1,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(4, 9), next_direction=1)),
+            WalkingElement(position=(4, 9), direction=1,
+                           next_action_element=RailEnvNextAction(action=2, next_position=(3, 9), next_direction=0))
+        )
+    ])
 
     assert actual == expected, "actual={},expected={}".format(actual, expected)