diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 77707b9f110376ddf2638b830830ff1a1c1edbf6..76095a2a2e1d9532951600118c6a777612641101 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -5,8 +5,9 @@ Collection of environment-specific PredictionBuilder.
 import numpy as np
 
 from flatland.core.env_prediction_builder import PredictionBuilder
-from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.distance_map import DistanceMap
 from flatland.envs.rail_env import RailEnvActions
+from flatland.envs.rail_env_shortest_paths import get_shortest_paths
 from flatland.utils.ordered_set import OrderedSet
 
 
@@ -59,7 +60,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
 
                     continue
                 for action in action_priorities:
-                    cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
+                    cell_is_free, new_cell_isValid, new_direction, new_position, transition_isValid = \
                         self.env._check_action_on_agent(action, agent)
                     if all([new_cell_isValid, transition_isValid]):
                         # move and change direction to face the new_direction that was
@@ -92,6 +93,9 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         """
         Called whenever get_many in the observation build is called.
         Requires distance_map to extract the shortest path.
+        Does not take into account future positions of other agents!
+
+        If there is no shortest path, the agent just stands still and stops moving.
 
         Parameters
         ----------
@@ -106,14 +110,15 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
             - position axis 0
             - position axis 1
             - direction
-            - action taken to come here
+            - action taken to come here (not implemented yet)
             The prediction at 0 is the current position, direction etc.
         """
         agents = self.env.agents
         if handle:
             agents = [self.env.agents[handle]]
-        distance_map = self.env.distance_map
-        assert distance_map is not None
+        distance_map: DistanceMap = self.env.distance_map
+
+        shortest_paths = get_shortest_paths(distance_map, max_depth=self.max_depth)
 
         prediction_dict = {}
         for agent in agents:
@@ -123,52 +128,35 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
             times_per_cell = int(np.reciprocal(agent_speed))
             prediction = np.zeros(shape=(self.max_depth + 1, 5))
             prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
+
+            shortest_path = shortest_paths[agent.handle]
+
+            # if there is a shortest path, remove the initial position
+            if shortest_path:
+                shortest_path = shortest_path[1:]
+
             new_direction = _agent_initial_direction
             new_position = _agent_initial_position
             visited = OrderedSet()
             for index in range(1, self.max_depth + 1):
-                # if we're at the target, stop moving...
-                if agent.position == agent.target:
-                    prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
-                    visited.add((agent.position[0], agent.position[1], agent.direction))
-                    continue
-                if not agent.moving:
-                    prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
-                    visited.add((agent.position[0], agent.position[1], agent.direction))
+                # if we're at the target or not moving, stop moving until max_depth is reached
+                if new_position == agent.target or not agent.moving or not shortest_path:
+                    prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING]
+                    visited.add((*new_position, agent.direction))
                     continue
-                # Take shortest possible path
-                cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
-
-                if np.sum(cell_transitions) == 1 and index % times_per_cell == 0:
-                    new_direction = np.argmax(cell_transitions)
-                    new_position = get_new_position(agent.position, new_direction)
-                elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0:
-                    min_dist = np.inf
-                    no_dist_found = True
-                    for direction in range(4):
-                        if cell_transitions[direction] == 1:
-                            neighbour_cell = get_new_position(agent.position, direction)
-                            target_dist = distance_map.get()[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
-                            if target_dist < min_dist or no_dist_found:
-                                min_dist = target_dist
-                                new_direction = direction
-                                no_dist_found = False
-                    new_position = get_new_position(agent.position, new_direction)
-                elif index % times_per_cell == 0:
-                    raise Exception("No transition possible {}".format(cell_transitions))
-
-                # update the agent's position and direction
-                agent.position = new_position
-                agent.direction = new_direction
+
+                if index % times_per_cell == 0:
+                    new_position = shortest_path[0].position
+                    new_direction = shortest_path[0].direction
+
+                    shortest_path = shortest_path[1:]
 
                 # prediction is ready
                 prediction[index] = [index, *new_position, new_direction, 0]
-                visited.add((new_position[0], new_position[1], new_direction))
+                visited.add((*new_position, new_direction))
+
+            # TODO: very bady side effects for visualization only: hand the dev_pred_dict back instead of setting on env!
             self.env.dev_pred_dict[agent.handle] = visited
             prediction_dict[agent.handle] = prediction
 
-            # cleanup: reset initial position
-            agent.position = _agent_initial_position
-            agent.direction = _agent_initial_direction
-
         return prediction_dict
diff --git a/flatland/envs/rail_env_shortest_paths.py b/flatland/envs/rail_env_shortest_paths.py
new file mode 100644
index 0000000000000000000000000000000000000000..793601d4d18ac38b729d15883089d5acbfc41ed3
--- /dev/null
+++ b/flatland/envs/rail_env_shortest_paths.py
@@ -0,0 +1,140 @@
+import math
+from typing import Dict, List, Optional, NamedTuple, Tuple, Set
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.distance_map import DistanceMap
+from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions
+from flatland.utils.ordered_set import OrderedSet
+
+WalkingElement = \
+    NamedTuple('WalkingElement',
+               [('position', Tuple[int, int]), ('direction', int), ('next_action_element', RailEnvNextAction)])
+
+
+def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
+                            agent_position: Tuple[int, int],
+                            rail: GridTransitionMap) -> Set[RailEnvNextAction]:
+    """
+    Get the valid move actions (forward, left, right) for an agent.
+
+    Parameters
+    ----------
+    agent_direction : Grid4TransitionsEnum
+    agent_position: Tuple[int,int]
+    rail : GridTransitionMap
+
+
+    Returns
+    -------
+    Set of `RailEnvNextAction` (tuples of (action,position,direction))
+        Possible move actions (forward,left,right) and the next position/direction they lead to.
+        It is not checked that the next cell is free.
+    """
+    valid_actions: Set[RailEnvNextAction] = OrderedSet()
+    possible_transitions = rail.get_transitions(*agent_position, agent_direction)
+    num_transitions = np.count_nonzero(possible_transitions)
+    # Start from the current orientation, and see which transitions are available;
+    # organize them as [left, forward, right], relative to the current orientation
+    # If only one transition is possible, the forward branch is aligned with it.
+    if rail.is_dead_end(agent_position):
+        action = RailEnvActions.MOVE_FORWARD
+        exit_direction = (agent_direction + 2) % 4
+        if possible_transitions[exit_direction]:
+            new_position = get_new_position(agent_position, exit_direction)
+            valid_actions.add(RailEnvNextAction(action, new_position, exit_direction))
+    elif num_transitions == 1:
+        action = RailEnvActions.MOVE_FORWARD
+        for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
+            if possible_transitions[new_direction]:
+                new_position = get_new_position(agent_position, new_direction)
+                valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
+    else:
+        for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
+            if possible_transitions[new_direction]:
+                if new_direction == agent_direction:
+                    action = RailEnvActions.MOVE_FORWARD
+                elif new_direction == (agent_direction + 1) % 4:
+                    action = RailEnvActions.MOVE_RIGHT
+                elif new_direction == (agent_direction - 1) % 4:
+                    action = RailEnvActions.MOVE_LEFT
+                else:
+                    raise Exception("Illegal state")
+
+                new_position = get_new_position(agent_position, new_direction)
+                valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
+    return valid_actions
+
+
+# N.B. get_shortest_paths is not part of distance_map since it refers to RailEnvActions (would lead to circularity!)
+def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = None) \
+    -> Dict[int, Optional[List[WalkingElement]]]:
+    """
+    Computes the shortest path for each agent to its target and the action to be taken to do so.
+    The paths are derived from a `DistanceMap`.
+
+    If there is no path (rail disconnected), the path is given as None.
+    The agent state (moving or not) and its speed are not taken into account
+
+    Parameters
+    ----------
+    distance_map
+
+    Returns
+    -------
+        Dict[int, Optional[List[WalkingElement]]]
+
+    """
+    shortest_paths = dict()
+
+    def _shortest_path_for_agent(agent):
+        position = agent.position
+        direction = agent.direction
+        shortest_paths[agent.handle] = []
+        distance = math.inf
+        depth = 0
+        while (position != agent.target and (max_depth is None or depth < max_depth)):
+            next_actions = get_valid_move_actions_(direction, position, distance_map.rail)
+            best_next_action = None
+            for next_action in next_actions:
+                next_action_distance = distance_map.get()[
+                    agent.handle, next_action.next_position[0], next_action.next_position[
+                        1], next_action.next_direction]
+                if next_action_distance < distance:
+                    best_next_action = next_action
+                    distance = next_action_distance
+
+            shortest_paths[agent.handle].append(WalkingElement(position, direction, best_next_action))
+            depth += 1
+
+            # if there is no way to continue, the rail must be disconnected!
+            # (or distance map is incorrect)
+            if best_next_action is None:
+                shortest_paths[agent.handle] = None
+                return
+
+            position = best_next_action.next_position
+            direction = best_next_action.next_direction
+        if max_depth is None or depth < max_depth:
+            shortest_paths[agent.handle].append(
+                WalkingElement(position, direction,
+                               RailEnvNextAction(RailEnvActions.STOP_MOVING, position, direction)))
+
+    for agent in distance_map.agents:
+        _shortest_path_for_agent(agent)
+
+    return shortest_paths
+
+
+def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0):
+    if agent_handle >= distance_map.get().shape[0]:
+        print("Error: agent_handle cannot be larger than actual number of agents")
+        return
+    # take min value of all 4 directions
+    min_distance_map = np.min(distance_map.get(), axis=3)
+    plt.imshow(min_distance_map[agent_handle][:][:])
+    plt.show()
diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py
index 0e305f48e14498382e1ad826b74e09cb469fee24..dc1cff12c0c8b1860859208a13d6403734a2d2ad 100644
--- a/flatland/envs/rail_env_utils.py
+++ b/flatland/envs/rail_env_utils.py
@@ -1,23 +1,8 @@
-import math
-from typing import Tuple, Set, Dict, List, NamedTuple
-
-import matplotlib.pyplot as plt
-import numpy as np
-
-from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.core.grid.grid4_utils import get_new_position
-from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.distance_map import DistanceMap
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
-from flatland.envs.rail_env import RailEnv, RailEnvNextAction, RailEnvActions
+from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import rail_from_file
 from flatland.envs.schedule_generators import schedule_from_file
-from flatland.utils.ordered_set import OrderedSet
-
-WalkingElement = \
-    NamedTuple('WalkingElement',
-               [('position', Tuple[int, int]), ('direction', int), ('next_action_element', RailEnvNextAction)])
 
 
 def load_flatland_environment_from_file(file_name, load_from_package=None, obs_builder_object=None):
@@ -32,98 +17,3 @@ def load_flatland_environment_from_file(file_name, load_from_package=None, obs_b
                           schedule_generator=schedule_from_file(file_name, load_from_package),
                           obs_builder_object=obs_builder_object)
     return environment
-
-
-def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
-                            agent_position: Tuple[int, int],
-                            rail: GridTransitionMap) -> Set[RailEnvNextAction]:
-    """
-    Get the valid move actions (forward, left, right) for an agent.
-
-    Parameters
-    ----------
-    agent_direction : Grid4TransitionsEnum
-    agent_position: Tuple[int,int]
-    rail : GridTransitionMap
-
-
-    Returns
-    -------
-    Set of `RailEnvNextAction` (tuples of (action,position,direction))
-        Possible move actions (forward,left,right) and the next position/direction they lead to.
-        It is not checked that the next cell is free.
-    """
-    valid_actions: Set[RailEnvNextAction] = OrderedSet()
-    possible_transitions = rail.get_transitions(*agent_position, agent_direction)
-    num_transitions = np.count_nonzero(possible_transitions)
-    # Start from the current orientation, and see which transitions are available;
-    # organize them as [left, forward, right], relative to the current orientation
-    # If only one transition is possible, the forward branch is aligned with it.
-    if rail.is_dead_end(agent_position):
-        action = RailEnvActions.MOVE_FORWARD
-        exit_direction = (agent_direction + 2) % 4
-        if possible_transitions[exit_direction]:
-            new_position = get_new_position(agent_position, exit_direction)
-            valid_actions.add(RailEnvNextAction(action, new_position, exit_direction))
-    elif num_transitions == 1:
-        action = RailEnvActions.MOVE_FORWARD
-        for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
-            if possible_transitions[new_direction]:
-                new_position = get_new_position(agent_position, new_direction)
-                valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
-    else:
-        for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
-            if possible_transitions[new_direction]:
-                if new_direction == agent_direction:
-                    action = RailEnvActions.MOVE_FORWARD
-                elif new_direction == (agent_direction + 1) % 4:
-                    action = RailEnvActions.MOVE_RIGHT
-                elif new_direction == (agent_direction - 1) % 4:
-                    action = RailEnvActions.MOVE_LEFT
-                else:
-                    raise Exception("Illegal state")
-
-                new_position = get_new_position(agent_position, new_direction)
-                valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
-    return valid_actions
-
-
-def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[WalkingElement]]:
-    # TODO: do we need to support unreachable targets?
-    # TODO refactoring: unify with predictor (support agent.moving and max_depth)
-    shortest_paths = dict()
-    for a in distance_map.agents:
-        position = a.position
-        direction = a.direction
-        shortest_paths[a.handle] = []
-        distance = math.inf
-        while (position != a.target):
-            next_actions = get_valid_move_actions_(direction, position, distance_map.rail)
-            best_next_action = None
-            for next_action in next_actions:
-                next_action_distance = distance_map.get()[
-                    a.handle, next_action.next_position[0], next_action.next_position[1], next_action.next_direction]
-                if next_action_distance < distance:
-                    best_next_action = next_action
-                    distance = next_action_distance
-
-            shortest_paths[a.handle].append(WalkingElement(position, direction, best_next_action))
-
-            position = best_next_action.next_position
-            direction = best_next_action.next_direction
-
-        shortest_paths[a.handle].append(
-            WalkingElement(position, direction,
-                           RailEnvNextAction(RailEnvActions.STOP_MOVING, position, direction)))
-
-    return shortest_paths
-
-
-def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0):
-    if agent_handle >= distance_map.get().shape[0]:
-        print("Error: agent_handle cannot be larger than actual number of agents")
-        return
-    # take min value of all 4 directions
-    min_distance_map = np.min(distance_map.get(), axis=3)
-    plt.imshow(min_distance_map[agent_handle][:][:])
-    plt.show()
diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py
index 6da29d7f6d1a52c42dd006b84f94a959990e0932..a12c26e66fdf5a9ff102bb79440bd4f4b805e819 100644
--- a/flatland/utils/simple_rail.py
+++ b/flatland/utils/simple_rail.py
@@ -45,6 +45,46 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
     return rail, rail_map
 
 
+def make_disconnected_simple_rail() -> Tuple[GridTransitionMap, np.array]:
+    # We instantiate a very simple rail network on a 7x10 grid:
+    # Note that that cells have invalid RailEnvTransitions!
+    #        |
+    #        |
+    #        |
+    # _ _ _ _\ _    _  _ _ _
+    #                /
+    #                |
+    #                |
+    #                |
+    transitions = RailEnvTransitions()
+    cells = transitions.transition_list
+    empty = cells[0]
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+    simple_switch_north_left = cells[2]
+    simple_switch_north_right = cells[10]
+    simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
+    simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270)
+    rail_map = np.array(
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
+        [[dead_end_from_east] + [horizontal_straight] * 2 +
+         [simple_switch_east_west_north] +
+         [dead_end_from_west]  + [dead_end_from_east] + [simple_switch_east_west_south] +
+         [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    return rail, rail_map
+
+
+
 def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
     # We instantiate a very simple rail network on a 7x10 grid:
     #        |
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index 569cd3addddf77e0e328e7d244ec57078a90281f..7ee0fd4aadf72b12b591259a71af8b408145418f 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -8,7 +8,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction
-from flatland.envs.rail_env_utils import get_shortest_paths, WalkingElement
+from flatland.envs.rail_env_shortest_paths import get_shortest_paths, WalkingElement
 from flatland.envs.rail_generators import rail_from_grid_transition_map
 from flatland.envs.schedule_generators import random_schedule_generator
 from flatland.utils.rendertools import RenderTool
@@ -236,12 +236,13 @@ def test_shortest_path_predictor(rendering=False):
         [20.],
     ])
 
+    assert np.array_equal(time_offsets, expected_time_offsets), \
+        "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
+
     assert np.array_equal(positions, expected_positions), \
         "positions {}, expected {}".format(positions, expected_positions)
     assert np.array_equal(directions, expected_directions), \
         "directions {}, expected {}".format(directions, expected_directions)
-    assert np.array_equal(time_offsets, expected_time_offsets), \
-        "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
 
 
 def test_shortest_path_predictor_conflicts(rendering=False):
diff --git a/tests/test_shortest_path.py b/tests/test_flatland_envs_rail_env_shortest_paths.py.py
similarity index 78%
rename from tests/test_shortest_path.py
rename to tests/test_flatland_envs_rail_env_shortest_paths.py.py
index 094b5603c062229aeca2049652721448c59bef5e..4600c4a3002995e1238a0ccbda762501ac985408 100644
--- a/tests/test_shortest_path.py
+++ b/tests/test_flatland_envs_rail_env_shortest_paths.py.py
@@ -1,7 +1,41 @@
 import numpy as np
 
-from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions
-from flatland.envs.rail_env_utils import load_flatland_environment_from_file, get_shortest_paths, WalkingElement
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import DummyPredictorForRailEnv
+from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions, RailEnv
+from flatland.envs.rail_env_shortest_paths import get_shortest_paths, WalkingElement
+from flatland.envs.rail_env_utils import load_flatland_environment_from_file
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator
+from flatland.utils.simple_rail import make_disconnected_simple_rail
+
+
+def test_get_shortest_paths_unreachable():
+    rail, rail_map = make_disconnected_simple_rail()
+
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
+                  )
+
+    # set the initial position
+    agent = env.agents_static[0]
+    agent.position = (3, 1)  # west dead-end
+    agent.direction = Grid4TransitionsEnum.WEST
+    agent.target = (3, 9)  # east dead-end
+    agent.moving = True
+
+    # reset to set agents from agents_static
+    env.reset(False, False)
+
+    actual = get_shortest_paths(env.distance_map)
+    expected = {0: None}
+
+    assert actual == expected, "actual={},expected={}".format(actual, expected)
 
 
 def test_get_shortest_paths():
@@ -130,3 +164,31 @@ def test_get_shortest_paths():
     for agent_handle in expected:
         assert np.array_equal(actual[agent_handle], expected[agent_handle]), \
             "[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle])
+
+
+def test_get_shortest_paths_max_depth():
+    env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests')
+    actual = get_shortest_paths(env.distance_map, max_depth=2)
+
+    expected = {
+        0: [
+            WalkingElement(position=(1, 1), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(1, 2), next_direction=1)),
+            WalkingElement(position=(1, 2), direction=1,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(1, 3), next_direction=1))
+        ],
+        1: [
+            WalkingElement(position=(3, 18), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(3, 17), next_direction=3)),
+            WalkingElement(position=(3, 17), direction=3,
+                           next_action_element=RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD,
+                                                                 next_position=(3, 16), next_direction=3)),
+        ]
+    }
+
+    for agent_handle in expected:
+        assert np.array_equal(actual[agent_handle], expected[agent_handle]), \
+            "[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle])