diff --git a/examples/training_example.py b/examples/training_example.py
index 8f7d18f55ec63f2ed2d57584817d14f8bc722259..c038e7b477069957efdec622b2c56e9e84cb7ac0 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -10,7 +10,6 @@ np.random.seed(1)
 
 # Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
 # Training on simple small tasks is the best way to get familiar with the environment
-#
 
 TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
 LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 88a79ea72606e7c1b46f92f6d73429979d67a4e6..4718ad9906db9b479123b53e9e9df0ff4db3b462 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -124,8 +124,12 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         for agent in agents:
             _agent_initial_position = agent.position
             _agent_initial_direction = agent.direction
+            agent_speed = agent.speed_data["speed"]
+            times_per_cell = int(np.reciprocal(agent_speed))
             prediction = np.zeros(shape=(self.max_depth + 1, 5))
             prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
+            new_direction = _agent_initial_direction
+            new_position = _agent_initial_position
             visited = set()
             for index in range(1, self.max_depth + 1):
                 # if we're at the target, stop moving...
@@ -140,12 +144,10 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                 # Take shortest possible path
                 cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
 
-                new_position = None
-                new_direction = None
-                if np.sum(cell_transitions) == 1:
+                if np.sum(cell_transitions) == 1 and index % times_per_cell == 0:
                     new_direction = np.argmax(cell_transitions)
                     new_position = get_new_position(agent.position, new_direction)
-                elif np.sum(cell_transitions) > 1:
+                elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0:
                     min_dist = np.inf
                     no_dist_found = True
                     for direction in range(4):
@@ -157,7 +159,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                                 new_direction = direction
                                 no_dist_found = False
                     new_position = get_new_position(agent.position, new_direction)
-                else:
+                elif index % times_per_cell == 0:
                     raise Exception("No transition possible {}".format(cell_transitions))
 
                 # update the agent's position and direction