diff --git a/env_data/tests/test_002.pkl b/env_data/tests/test_002.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..e46a5c088ef555109c352b7a4d5eaf8dfbaf8700
Binary files /dev/null and b/env_data/tests/test_002.pkl differ
diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py
index c9c6b00375ef4577880e2b8c98c2ff9dc946a7fa..22721407f059ff2e02d907110cb9f982aa9d599e 100644
--- a/flatland/envs/distance_map.py
+++ b/flatland/envs/distance_map.py
@@ -18,27 +18,21 @@ class DistanceMap:
         self.agents: List[EnvAgent] = agents
         self.rail: Optional[GridTransitionMap] = None
 
-    """
-    Set the distance map
-    """
     def set(self, distance_map: np.ndarray):
+        """
+        Set the distance map
+        """
         self.distance_map = distance_map
 
-    """
-    Get the distance map
-    """
     def get(self) -> np.ndarray:
-
+        """
+        Get the distance map
+        """
         if self.reset_was_called:
             self.reset_was_called = False
 
             nb_agents = len(self.agents)
             compute_distance_map = True
-            if self.agents_previous_computation is not None and nb_agents == len(self.agents_previous_computation):
-                compute_distance_map = False
-                for i in range(nb_agents):
-                    if self.agents[i].target != self.agents_previous_computation[i].target:
-                        compute_distance_map = True
             # Don't compute the distance map if it was loaded
             if self.agents_previous_computation is None and self.distance_map is not None:
                 compute_distance_map = False
@@ -51,12 +45,12 @@ class DistanceMap:
 
         return self.distance_map
 
-    """
-    Reset the distance map
-    """
     def reset(self, agents: List[EnvAgent], rail: GridTransitionMap):
+        """
+        Reset the distance map
+        """
         self.reset_was_called = True
-        self.agents = agents
+        self.agents: List[EnvAgent] = agents
         self.rail = rail
         self.env_height = rail.height
         self.env_width = rail.width
@@ -110,7 +104,8 @@ class DistanceMap:
 
         return max_distance
 
-    def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance, enforce_target_direction=-1):
+    def _get_and_update_neighbors(self, rail: GridTransitionMap, position, target_nr, current_distance,
+                                  enforce_target_direction=-1):
         """
         Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
         minimum distances from each target cell.
@@ -134,8 +129,7 @@ class DistanceMap:
                 for agent_orientation in range(4):
                     # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
                     is_valid = rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
-                                                            desired_movement_from_new_cell)
-                    # is_valid = True
+                                                   desired_movement_from_new_cell)
 
                     if is_valid:
                         """
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 77707b9f110376ddf2638b830830ff1a1c1edbf6..76095a2a2e1d9532951600118c6a777612641101 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -5,8 +5,9 @@ Collection of environment-specific PredictionBuilder.
 import numpy as np
 
 from flatland.core.env_prediction_builder import PredictionBuilder
-from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.distance_map import DistanceMap
 from flatland.envs.rail_env import RailEnvActions
+from flatland.envs.rail_env_shortest_paths import get_shortest_paths
 from flatland.utils.ordered_set import OrderedSet
 
 
@@ -59,7 +60,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
 
                     continue
                 for action in action_priorities:
-                    cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
+                    cell_is_free, new_cell_isValid, new_direction, new_position, transition_isValid = \
                         self.env._check_action_on_agent(action, agent)
                     if all([new_cell_isValid, transition_isValid]):
                         # move and change direction to face the new_direction that was
@@ -92,6 +93,9 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         """
         Called whenever get_many in the observation build is called.
         Requires distance_map to extract the shortest path.
+        Does not take into account future positions of other agents!
+
+        If there is no shortest path, the agent just stands still and stops moving.
 
         Parameters
         ----------
@@ -106,14 +110,15 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
             - position axis 0
             - position axis 1
             - direction
-            - action taken to come here
+            - action taken to come here (not implemented yet)
             The prediction at 0 is the current position, direction etc.
         """
         agents = self.env.agents
         if handle:
             agents = [self.env.agents[handle]]
-        distance_map = self.env.distance_map
-        assert distance_map is not None
+        distance_map: DistanceMap = self.env.distance_map
+
+        shortest_paths = get_shortest_paths(distance_map, max_depth=self.max_depth)
 
         prediction_dict = {}
         for agent in agents:
@@ -123,52 +128,35 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
             times_per_cell = int(np.reciprocal(agent_speed))
             prediction = np.zeros(shape=(self.max_depth + 1, 5))
             prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
+
+            shortest_path = shortest_paths[agent.handle]
+
+            # if there is a shortest path, remove the initial position
+            if shortest_path:
+                shortest_path = shortest_path[1:]
+
             new_direction = _agent_initial_direction
             new_position = _agent_initial_position
             visited = OrderedSet()
             for index in range(1, self.max_depth + 1):
-                # if we're at the target, stop moving...
-                if agent.position == agent.target:
-                    prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
-                    visited.add((agent.position[0], agent.position[1], agent.direction))
-                    continue
-                if not agent.moving:
-                    prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
-                    visited.add((agent.position[0], agent.position[1], agent.direction))
+                # if we're at the target or not moving, stop moving until max_depth is reached
+                if new_position == agent.target or not agent.moving or not shortest_path:
+                    prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING]
+                    visited.add((*new_position, agent.direction))
                     continue
-                # Take shortest possible path
-                cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
-
-                if np.sum(cell_transitions) == 1 and index % times_per_cell == 0:
-                    new_direction = np.argmax(cell_transitions)
-                    new_position = get_new_position(agent.position, new_direction)
-                elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0:
-                    min_dist = np.inf
-                    no_dist_found = True
-                    for direction in range(4):
-                        if cell_transitions[direction] == 1:
-                            neighbour_cell = get_new_position(agent.position, direction)
-                            target_dist = distance_map.get()[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
-                            if target_dist < min_dist or no_dist_found:
-                                min_dist = target_dist
-                                new_direction = direction
-                                no_dist_found = False
-                    new_position = get_new_position(agent.position, new_direction)
-                elif index % times_per_cell == 0:
-                    raise Exception("No transition possible {}".format(cell_transitions))
-
-                # update the agent's position and direction
-                agent.position = new_position
-                agent.direction = new_direction
+
+                if index % times_per_cell == 0:
+                    new_position = shortest_path[0].position
+                    new_direction = shortest_path[0].direction
+
+                    shortest_path = shortest_path[1:]
 
                 # prediction is ready
                 prediction[index] = [index, *new_position, new_direction, 0]
-                visited.add((new_position[0], new_position[1], new_direction))
+                visited.add((*new_position, new_direction))
+
+            # TODO: very bady side effects for visualization only: hand the dev_pred_dict back instead of setting on env!
             self.env.dev_pred_dict[agent.handle] = visited
             prediction_dict[agent.handle] = prediction
 
-            # cleanup: reset initial position
-            agent.position = _agent_initial_position
-            agent.direction = _agent_initial_direction
-
         return prediction_dict
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index d0add3086014c7ad07c29e01588cba380c26cbd7..b5bc44f2e698a4ffa23ca31d34ac14f613340d04 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -4,7 +4,7 @@ Definition of the RailEnv environment.
 # TODO:  _ this is a global method --> utils or remove later
 import warnings
 from enum import IntEnum
-from typing import List, Set, NamedTuple, Optional, Tuple, Dict
+from typing import List, NamedTuple, Optional, Tuple, Dict
 
 import msgpack
 import msgpack_numpy as m
@@ -20,7 +20,6 @@ from flatland.envs.distance_map import DistanceMap
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_generators import random_rail_generator, RailGenerator
 from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
-from flatland.utils.ordered_set import OrderedSet
 
 m.patch()
 
@@ -241,9 +240,6 @@ class RailEnv(Environment):
         #  can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
         rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
-        if optionals and 'distance_map' in optionals:
-            self.distance_map.set(optionals['distance_map'])
-
         if regen_rail or self.rail is None:
             self.rail = rail
             self.height, self.width = self.rail.grid.shape
@@ -253,6 +249,11 @@ class RailEnv(Environment):
                     check = self.rail.cell_neighbours_valid(rc_pos, True)
                     if not check:
                         warnings.warn("Invalid grid at {} -> {}".format(rc_pos, check))
+        # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/172
+        #  hacky: we must re-compute the distance map and not use the initial distance_map loaded from file by
+        #  rail_from_file!!!
+        elif optionals and 'distance_map' in optionals:
+            self.distance_map.set(optionals['distance_map'])
 
         if replace_agents:
             agents_hints = None
@@ -587,60 +588,6 @@ class RailEnv(Environment):
             transition_valid = True
         return new_direction, transition_valid
 
-    @staticmethod
-    def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
-                                agent_position: Tuple[int, int],
-                                rail: GridTransitionMap) -> Set[RailEnvNextAction]:
-        """
-        Get the valid move actions (forward, left, right) for an agent.
-
-        Parameters
-        ----------
-        agent_direction : Grid4TransitionsEnum
-        agent_position: Tuple[int,int]
-        rail : GridTransitionMap
-
-
-        Returns
-        -------
-        Set of `RailEnvNextAction` (tuples of (action,position,direction))
-            Possible move actions (forward,left,right) and the next position/direction they lead to.
-            It is not checked that the next cell is free.
-        """
-        valid_actions: Set[RailEnvNextAction] = OrderedSet()
-        possible_transitions = rail.get_transitions(*agent_position, agent_direction)
-        num_transitions = np.count_nonzero(possible_transitions)
-        # Start from the current orientation, and see which transitions are available;
-        # organize them as [left, forward, right], relative to the current orientation
-        # If only one transition is possible, the forward branch is aligned with it.
-        if rail.is_dead_end(agent_position):
-            action = RailEnvActions.MOVE_FORWARD
-            exit_direction = (agent_direction + 2) % 4
-            if possible_transitions[exit_direction]:
-                new_position = get_new_position(agent_position, exit_direction)
-                valid_actions.add(RailEnvNextAction(action, new_position, exit_direction))
-        elif num_transitions == 1:
-            action = RailEnvActions.MOVE_FORWARD
-            for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
-                if possible_transitions[new_direction]:
-                    new_position = get_new_position(agent_position, new_direction)
-                    valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
-        else:
-            for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
-                if possible_transitions[new_direction]:
-                    if new_direction == agent_direction:
-                        action = RailEnvActions.MOVE_FORWARD
-                    elif new_direction == (agent_direction + 1) % 4:
-                        action = RailEnvActions.MOVE_RIGHT
-                    elif new_direction == (agent_direction - 1) % 4:
-                        action = RailEnvActions.MOVE_LEFT
-                    else:
-                        raise Exception("Illegal state")
-
-                    new_position = get_new_position(agent_position, new_direction)
-                    valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
-        return valid_actions
-
     def _get_observations(self):
         self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
         return self.obs_dict
diff --git a/flatland/envs/rail_env_shortest_paths.py b/flatland/envs/rail_env_shortest_paths.py
new file mode 100644
index 0000000000000000000000000000000000000000..793601d4d18ac38b729d15883089d5acbfc41ed3
--- /dev/null
+++ b/flatland/envs/rail_env_shortest_paths.py
@@ -0,0 +1,140 @@
+import math
+from typing import Dict, List, Optional, NamedTuple, Tuple, Set
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.distance_map import DistanceMap
+from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions
+from flatland.utils.ordered_set import OrderedSet
+
+WalkingElement = \
+    NamedTuple('WalkingElement',
+               [('position', Tuple[int, int]), ('direction', int), ('next_action_element', RailEnvNextAction)])
+
+
+def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
+                            agent_position: Tuple[int, int],
+                            rail: GridTransitionMap) -> Set[RailEnvNextAction]:
+    """
+    Get the valid move actions (forward, left, right) for an agent.
+
+    Parameters
+    ----------
+    agent_direction : Grid4TransitionsEnum
+    agent_position: Tuple[int,int]
+    rail : GridTransitionMap
+
+
+    Returns
+    -------
+    Set of `RailEnvNextAction` (tuples of (action,position,direction))
+        Possible move actions (forward,left,right) and the next position/direction they lead to.
+        It is not checked that the next cell is free.
+    """
+    valid_actions: Set[RailEnvNextAction] = OrderedSet()
+    possible_transitions = rail.get_transitions(*agent_position, agent_direction)
+    num_transitions = np.count_nonzero(possible_transitions)
+    # Start from the current orientation, and see which transitions are available;
+    # organize them as [left, forward, right], relative to the current orientation
+    # If only one transition is possible, the forward branch is aligned with it.
+    if rail.is_dead_end(agent_position):
+        action = RailEnvActions.MOVE_FORWARD
+        exit_direction = (agent_direction + 2) % 4
+        if possible_transitions[exit_direction]:
+            new_position = get_new_position(agent_position, exit_direction)
+            valid_actions.add(RailEnvNextAction(action, new_position, exit_direction))
+    elif num_transitions == 1:
+        action = RailEnvActions.MOVE_FORWARD
+        for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
+            if possible_transitions[new_direction]:
+                new_position = get_new_position(agent_position, new_direction)
+                valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
+    else:
+        for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
+            if possible_transitions[new_direction]:
+                if new_direction == agent_direction:
+                    action = RailEnvActions.MOVE_FORWARD
+                elif new_direction == (agent_direction + 1) % 4:
+                    action = RailEnvActions.MOVE_RIGHT
+                elif new_direction == (agent_direction - 1) % 4:
+                    action = RailEnvActions.MOVE_LEFT
+                else:
+                    raise Exception("Illegal state")
+
+                new_position = get_new_position(agent_position, new_direction)
+                valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
+    return valid_actions
+
+
+# N.B. get_shortest_paths is not part of distance_map since it refers to RailEnvActions (would lead to circularity!)
+def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = None) \
+    -> Dict[int, Optional[List[WalkingElement]]]:
+    """
+    Computes the shortest path for each agent to its target and the action to be taken to do so.
+    The paths are derived from a `DistanceMap`.
+
+    If there is no path (rail disconnected), the path is given as None.
+    The agent state (moving or not) and its speed are not taken into account
+
+    Parameters
+    ----------
+    distance_map
+
+    Returns
+    -------
+        Dict[int, Optional[List[WalkingElement]]]
+
+    """
+    shortest_paths = dict()
+
+    def _shortest_path_for_agent(agent):
+        position = agent.position
+        direction = agent.direction
+        shortest_paths[agent.handle] = []
+        distance = math.inf
+        depth = 0
+        while (position != agent.target and (max_depth is None or depth < max_depth)):
+            next_actions = get_valid_move_actions_(direction, position, distance_map.rail)
+            best_next_action = None
+            for next_action in next_actions:
+                next_action_distance = distance_map.get()[
+                    agent.handle, next_action.next_position[0], next_action.next_position[
+                        1], next_action.next_direction]
+                if next_action_distance < distance:
+                    best_next_action = next_action
+                    distance = next_action_distance
+
+            shortest_paths[agent.handle].append(WalkingElement(position, direction, best_next_action))
+            depth += 1
+
+            # if there is no way to continue, the rail must be disconnected!
+            # (or distance map is incorrect)
+            if best_next_action is None:
+                shortest_paths[agent.handle] = None
+                return
+
+            position = best_next_action.next_position
+            direction = best_next_action.next_direction
+        if max_depth is None or depth < max_depth:
+            shortest_paths[agent.handle].append(
+                WalkingElement(position, direction,
+                               RailEnvNextAction(RailEnvActions.STOP_MOVING, position, direction)))
+
+    for agent in distance_map.agents:
+        _shortest_path_for_agent(agent)
+
+    return shortest_paths
+
+
+def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0):
+    if agent_handle >= distance_map.get().shape[0]:
+        print("Error: agent_handle cannot be larger than actual number of agents")
+        return
+    # take min value of all 4 directions
+    min_distance_map = np.min(distance_map.get(), axis=3)
+    plt.imshow(min_distance_map[agent_handle][:][:])
+    plt.show()
diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py
index 69cfce764fe124d9e3eb05019e1c734f2285bdb5..dc1cff12c0c8b1860859208a13d6403734a2d2ad 100644
--- a/flatland/envs/rail_env_utils.py
+++ b/flatland/envs/rail_env_utils.py
@@ -1,7 +1,3 @@
-import numpy as np
-import matplotlib.pyplot as plt
-
-from flatland.envs.distance_map import DistanceMap
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
@@ -21,13 +17,3 @@ def load_flatland_environment_from_file(file_name, load_from_package=None, obs_b
                           schedule_generator=schedule_from_file(file_name, load_from_package),
                           obs_builder_object=obs_builder_object)
     return environment
-
-
-def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0):
-    if agent_handle >= distance_map.get().shape[0]:
-        print("Error: agent_handle cannot be larger than actual number of agents")
-        return
-    # take min value of all 4 directions
-    min_distance_map = np.min(distance_map.get(), axis=3)
-    plt.imshow(min_distance_map[agent_handle][:][:])
-    plt.show()
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index 92a0f84f35fa942b03236c6add6e722475a2d842..4dad2ca872725517ffd47308f088f098b6abe1aa 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -174,7 +174,6 @@ class PILGL(GraphicsLayer):
         self.draws[layer].text(xyPixLeftTop, strText, font=self.font, fill=(0, 0, 0, 255))
 
     def text_rowcol(self, rcTopLeft, strText, layer=AGENT_LAYER):
-        print("Text:", "rc:", rcTopLeft, "text:", strText, "layer:", layer)
         xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]])
         self.text(*xyPixLeftTop, strText, layer)
 
@@ -606,7 +605,6 @@ class PILSVG(PILGL):
             self.draw_image_row_col(bg_svg, (row, col), layer=PILGL.SELECTED_AGENT_LAYER)
 
         if show_debug:
-            print("Call text:")
             self.text_rowcol((row + 0.2, col + 0.2,), str(agent_idx))
 
     def set_cell_occupied(self, agent_idx, row, col):
diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py
index 6da29d7f6d1a52c42dd006b84f94a959990e0932..a12c26e66fdf5a9ff102bb79440bd4f4b805e819 100644
--- a/flatland/utils/simple_rail.py
+++ b/flatland/utils/simple_rail.py
@@ -45,6 +45,46 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
     return rail, rail_map
 
 
+def make_disconnected_simple_rail() -> Tuple[GridTransitionMap, np.array]:
+    # We instantiate a very simple rail network on a 7x10 grid:
+    # Note that that cells have invalid RailEnvTransitions!
+    #        |
+    #        |
+    #        |
+    # _ _ _ _\ _    _  _ _ _
+    #                /
+    #                |
+    #                |
+    #                |
+    transitions = RailEnvTransitions()
+    cells = transitions.transition_list
+    empty = cells[0]
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+    simple_switch_north_left = cells[2]
+    simple_switch_north_right = cells[10]
+    simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
+    simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270)
+    rail_map = np.array(
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
+        [[dead_end_from_east] + [horizontal_straight] * 2 +
+         [simple_switch_east_west_north] +
+         [dead_end_from_west]  + [dead_end_from_east] + [simple_switch_east_west_south] +
+         [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    return rail, rail_map
+
+
+
 def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
     # We instantiate a very simple rail network on a 7x10 grid:
     #        |
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index c31494673e63a17dc07eb6d89eeb581c640b1e13..7ee0fd4aadf72b12b591259a71af8b408145418f 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -7,7 +7,8 @@ import numpy as np
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
-from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction
+from flatland.envs.rail_env_shortest_paths import get_shortest_paths, WalkingElement
 from flatland.envs.rail_generators import rail_from_grid_transition_map
 from flatland.envs.schedule_generators import random_schedule_generator
 from flatland.utils.rendertools import RenderTool
@@ -142,6 +143,21 @@ def test_shortest_path_predictor(rendering=False):
         1], agent.direction] == 5.0, "found {} instead of {}".format(
         distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
 
+    paths = get_shortest_paths(env.distance_map)[0]
+    assert paths == [
+        WalkingElement((5, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6),
+                                                    next_direction=0)),
+        WalkingElement((4, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6),
+                                                    next_direction=0)),
+        WalkingElement((3, 6), 0, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 7),
+                                                    next_direction=1)),
+        WalkingElement((3, 7), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 8),
+                                                    next_direction=1)),
+        WalkingElement((3, 8), 1, RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 9),
+                                                    next_direction=1)),
+        WalkingElement((3, 9), 1, RailEnvNextAction(action=RailEnvActions.STOP_MOVING, next_position=(3, 9),
+                                                    next_direction=1))]
+
     # extract the data
     predictions = env.obs_builder.predictions
     positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
@@ -220,12 +236,13 @@ def test_shortest_path_predictor(rendering=False):
         [20.],
     ])
 
+    assert np.array_equal(time_offsets, expected_time_offsets), \
+        "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
+
     assert np.array_equal(positions, expected_positions), \
         "positions {}, expected {}".format(positions, expected_positions)
     assert np.array_equal(directions, expected_directions), \
         "directions {}, expected {}".format(directions, expected_directions)
-    assert np.array_equal(time_offsets, expected_time_offsets), \
-        "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
 
 
 def test_shortest_path_predictor_conflicts(rendering=False):
diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py
new file mode 100644
index 0000000000000000000000000000000000000000..4600c4a3002995e1238a0ccbda762501ac985408
--- /dev/null
+++ b/tests/test_flatland_envs_rail_env_shortest_paths.py
@@ -0,0 +1,194 @@
+import numpy as np
+
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import DummyPredictorForRailEnv
+from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions, RailEnv
+from flatland.envs.rail_env_shortest_paths import get_shortest_paths, WalkingElement
+from flatland.envs.rail_env_utils import load_flatland_environment_from_file
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator
+from flatland.utils.simple_rail import make_disconnected_simple_rail
+
+
+def test_get_shortest_paths_unreachable():
+    rail, rail_map = make_disconnected_simple_rail()
+
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
+                  )
+
+    # set the initial position
+    agent = env.agents_static[0]
+    agent.position = (3, 1)  # west dead-end
+    agent.direction = Grid4TransitionsEnum.WEST
+    agent.target = (3, 9)  # east dead-end
+    agent.moving = True
+
+    # reset to set agents from agents_static
+    env.reset(False, False)
+
+    actual = get_shortest_paths(env.distance_map)
+    expected = {0: None}
+
+    assert actual == expected, "actual={},expected={}".format(actual, expected)
+
+
+def test_get_shortest_paths():
+    env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests')
+    actual = get_shortest_paths(env.distance_map)
+
+    expected = {
+        0: [
+            WalkingElement(position=(1, 1), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(1, 2), next_direction=1)),
+            WalkingElement(position=(1, 2), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(1, 3), next_direction=1)),
+            WalkingElement(position=(1, 3), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 3), next_direction=2)),
+            WalkingElement(position=(2, 3), direction=2,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 4), next_direction=1)),
+            WalkingElement(position=(2, 4), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 5), next_direction=1)),
+            WalkingElement(position=(2, 5), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 6), next_direction=1)),
+            WalkingElement(position=(2, 6), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 7), next_direction=1)),
+            WalkingElement(position=(2, 7), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 8), next_direction=1)),
+            WalkingElement(position=(2, 8), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 9), next_direction=1)),
+            WalkingElement(position=(2, 9), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 10), next_direction=1)),
+            WalkingElement(position=(2, 10), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 11), next_direction=1)),
+            WalkingElement(position=(2, 11), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 12), next_direction=1)),
+            WalkingElement(position=(2, 12), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 13), next_direction=1)),
+            WalkingElement(position=(2, 13), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 14), next_direction=1)),
+            WalkingElement(position=(2, 14), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 15), next_direction=1)),
+            WalkingElement(position=(2, 15), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 16), next_direction=1)),
+            WalkingElement(position=(2, 16), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 17), next_direction=1)),
+            WalkingElement(position=(2, 17), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 18), next_direction=1)),
+            WalkingElement(position=(2, 18), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.STOP_MOVING,
+                                                                 next_position=(2, 18), next_direction=1))],
+        1: [
+            WalkingElement(position=(3, 18), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(3, 17), next_direction=3)),
+            WalkingElement(position=(3, 17), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(3, 16), next_direction=3)),
+            WalkingElement(position=(3, 16), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 16), next_direction=0)),
+            WalkingElement(position=(2, 16), direction=0,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 15), next_direction=3)),
+            WalkingElement(position=(2, 15), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 14), next_direction=3)),
+            WalkingElement(position=(2, 14), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 13), next_direction=3)),
+            WalkingElement(position=(2, 13), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 12), next_direction=3)),
+            WalkingElement(position=(2, 12), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 11), next_direction=3)),
+            WalkingElement(position=(2, 11), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 10), next_direction=3)),
+            WalkingElement(position=(2, 10), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 9), next_direction=3)),
+            WalkingElement(position=(2, 9), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 8), next_direction=3)),
+            WalkingElement(position=(2, 8), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 7), next_direction=3)),
+            WalkingElement(position=(2, 7), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 6), next_direction=3)),
+            WalkingElement(position=(2, 6), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 5), next_direction=3)),
+            WalkingElement(position=(2, 5), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 4), next_direction=3)),
+            WalkingElement(position=(2, 4), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 3), next_direction=3)),
+            WalkingElement(position=(2, 3), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 2), next_direction=3)),
+            WalkingElement(position=(2, 2), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(2, 1), next_direction=3)),
+            WalkingElement(position=(2, 1), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.STOP_MOVING,
+                                                                 next_position=(2, 1), next_direction=3))]
+    }
+
+    for agent_handle in expected:
+        assert np.array_equal(actual[agent_handle], expected[agent_handle]), \
+            "[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle])
+
+
+def test_get_shortest_paths_max_depth():
+    env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests')
+    actual = get_shortest_paths(env.distance_map, max_depth=2)
+
+    expected = {
+        0: [
+            WalkingElement(position=(1, 1), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(1, 2), next_direction=1)),
+            WalkingElement(position=(1, 2), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(1, 3), next_direction=1))
+        ],
+        1: [
+            WalkingElement(position=(3, 18), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(3, 17), next_direction=3)),
+            WalkingElement(position=(3, 17), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(3, 16), next_direction=3)),
+        ]
+    }
+
+    for agent_handle in expected:
+        assert np.array_equal(actual[agent_handle], expected[agent_handle]), \
+            "[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle])