diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py
index c9595b7693497daa7db110b8fc8b4ae040d39cc9..ee2c263711906370900c14a6030d57ada573c2ea 100644
--- a/flatland/envs/env_utils.py
+++ b/flatland/envs/env_utils.py
@@ -7,6 +7,8 @@ a GridTransitionMap object.
 
 import numpy as np
 
+from flatland.core.transitions import Grid4TransitionsEnum
+
 
 def get_direction(pos1, pos2):
     """
@@ -253,13 +255,14 @@ def distance_on_rail(pos1, pos2):
 
 
 def get_new_position(position, movement):
-    if movement == 0:  # NORTH
+    """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
+    if movement == Grid4TransitionsEnum.NORTH:
         return (position[0] - 1, position[1])
-    elif movement == 1:  # EAST
+    elif movement == Grid4TransitionsEnum.EAST:
         return (position[0], position[1] + 1)
-    elif movement == 2:  # SOUTH
+    elif movement == Grid4TransitionsEnum.SOUTH:
         return (position[0] + 1, position[1])
-    elif movement == 3:  # WEST
+    elif movement == Grid4TransitionsEnum.WEST:
         return (position[0], position[1] - 1)
 
 
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index d7fdcee7cf0f1183f9430c08684e4ede468d188a..a7f91f1439f98bc2627700903b0175486f619749 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -6,6 +6,7 @@ from collections import deque
 import numpy as np
 
 from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.transitions import Grid4TransitionsEnum
 from flatland.envs.env_utils import coordinate_to_position
 
 
@@ -162,13 +163,13 @@ class TreeObsForRailEnv(ObservationBuilder):
         """
         Utility function that converts a compass movement over a 2D grid to new positions (r, c).
         """
-        if movement == 0:  # NORTH
+        if movement == Grid4TransitionsEnum.NORTH:
             return (position[0] - 1, position[1])
-        elif movement == 1:  # EAST
+        elif movement == Grid4TransitionsEnum.EAST:
             return (position[0], position[1] + 1)
-        elif movement == 2:  # SOUTH
+        elif movement == Grid4TransitionsEnum.SOUTH:
             return (position[0] + 1, position[1])
-        elif movement == 3:  # WEST
+        elif movement == Grid4TransitionsEnum.WEST:
             return (position[0], position[1] - 1)
 
     def get_many(self, handles=[]):
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 3910fa1b6a3deb5055841af4381957a0a474c3a7..e1b90b8ac4cdaa1784f9e056c5c4c45ad10fb1ff 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder.
 import numpy as np
 
 from flatland.core.env_prediction_builder import PredictionBuilder
+from flatland.envs.env_utils import get_new_position
 from flatland.envs.rail_env import RailEnvActions
 
 
@@ -55,8 +56,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
                 action_done = False
                 # if we're at the target, stop moving...
                 if agent.position == agent.target:
-                    prediction[index] = [index, *agent.target, agent.direction,
-                                         RailEnvActions.STOP_MOVING]
+                    prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
 
                     continue
                 for action in action_priorities:
@@ -135,7 +135,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
 
                 if np.sum(cell_transitions) == 1:
                     new_direction = np.argmax(cell_transitions)
-                    new_position = self._new_position(agent.position, new_direction)
+                    new_position = get_new_position(agent.position, new_direction)
                 elif np.sum(cell_transitions) > 1:
                     min_dist = np.inf
                     for direction in range(4):
@@ -144,11 +144,22 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                             if target_dist < min_dist:
                                 min_dist = target_dist
                                 new_direction = direction
-                    new_position = self._new_position(agent.position, new_direction)
+                    new_position = get_new_position(agent.position, new_direction)
+                else:
+                    raise Exception("No transition possible {}".format(cell_transitions))
+
 
+                action = None
+                for _action in [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT]:
+                    cell_isFree, new_cell_isValid, new_direction, _new_position, transition_isValid = \
+                        self.env._check_action_on_agent(action, agent)
+                    if np.array_equal(_new_position, new_position):
+                        action = _action
+                        break
+                assert action is not None
                 agent.position = new_position
                 agent.direction = new_direction
-                prediction[index] = [index, *new_position, new_direction, RailEnvActions.MOVE_FORWARD]
+                prediction[index] = [index, *new_position, new_direction, action]
                 action_done = True
                 if not action_done:
                     raise Exception("Cannot move further. Something is wrong")
@@ -159,16 +170,3 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
             agent.direction = _agent_initial_direction
 
         return prediction_dict
-
-    def _new_position(self, position, movement):
-        """
-        Utility function that converts a compass movement over a 2D grid to new positions (r, c).
-        """
-        if movement == 0:  # NORTH
-            return (position[0] - 1, position[1])
-        elif movement == 1:  # EAST
-            return (position[0], position[1] + 1)
-        elif movement == 2:  # SOUTH
-            return (position[0] + 1, position[1])
-        elif movement == 3:  # WEST
-            return (position[0], position[1] - 1)