diff --git a/examples/training_example.py b/examples/training_example.py
index c785804001e96b3474794804b7e63ef78a450f01..78336b96e2b4e9ed7da1be3253d83c769a9a21f4 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -2,7 +2,7 @@ import numpy as np
 
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import DummyPredictorForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 
 np.random.seed(1)
@@ -11,7 +11,7 @@ np.random.seed(1)
 # Training on simple small tasks is the best way to get familiar with the environment
 #
 
-TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv())
+TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
 env = RailEnv(width=20,
               height=20,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 4fa0e09e35a202088cba67752325a387e316bd53..885a966b87e23e6294b090fa7c9338ff3893cb60 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -177,7 +177,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         if self.predictor:
             self.predicted_pos = {}
             self.predicted_dir = {}
-            self.predictions = self.predictor.get()
+            self.predictions = self.predictor.get(self.distance_map)
 
             for t in range(len(self.predictions[0])):
                 pos_list = []
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index f7dec074e6dbf43694ddf5284675cebb10ea5b59..988594e63ad72b895127934258be19c1de85d0b6 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -16,7 +16,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
     The prediction acts as if no other agent is in the environment and always takes the forward action.
     """
 
-    def get(self, handle=None):
+    def get(self, distancemap, handle=None):
         """
         Called whenever predict is called on the environment.
 
@@ -72,3 +72,92 @@ class DummyPredictorForRailEnv(PredictionBuilder):
             agent.position = _agent_initial_position
             agent.direction = _agent_initial_direction
         return prediction_dict
+
+
+class ShortestPathPredictorForRailEnv(PredictionBuilder):
+    """
+    DummyPredictorForRailEnv object.
+
+    This object returns predictions for agents in the RailEnv environment.
+    The prediction acts as if no other agent is in the environment and always takes the forward action.
+    """
+
+    def get(self, distancemap, handle=None):
+        """
+        Called whenever predict is called on the environment.
+
+        Parameters
+        -------
+        handle : int (optional)
+            Handle of the agent for which to compute the observation vector.
+
+        Returns
+        -------
+        function
+            Returns a dictionary index by the agent handle and for each agent a vector of 5 elements:
+            - time_offset
+            - position axis 0
+            - position axis 1
+            - direction
+            - action taken to come here
+        """
+        agents = self.env.agents
+        if handle:
+            agents = [self.env.agents[handle]]
+
+        prediction_dict = {}
+        agent_idx = 0
+        for agent in agents:
+            action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
+            _agent_initial_position = agent.position
+            _agent_initial_direction = agent.direction
+            prediction = np.zeros(shape=(self.max_depth + 1, 5))
+            prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
+            for index in range(1, self.max_depth + 1):
+                action_done = False
+                # if we're at the target, stop moving...
+                if agent.position == agent.target:
+                    prediction[index] = [index, agent.target[0], agent.target[1], agent.direction,
+                                         RailEnvActions.STOP_MOVING]
+
+                    continue
+                # Take shortest possible path
+                cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
+
+                if np.sum(cell_transitions) == 1:
+                    new_direction = np.argmax(cell_transitions)
+                    new_position = self._new_position(agent.position, new_direction)
+                else:
+                    for direct in range(4):
+                        min_dist = np.inf
+                        target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
+                        if target_dist < min_dist:
+                            min_dist = target_dist
+                            new_direction = direct
+                            new_position = self._new_position(agent.position, new_direction)
+
+                agent.position = new_position
+                agent.direction = new_direction
+                prediction[index] = [index, new_position[0], new_position[1], new_direction, 0]
+                action_done = True
+                if not action_done:
+                    raise Exception("Cannot move further. Something is wrong")
+            prediction_dict[agent.handle] = prediction
+            agent.position = _agent_initial_position
+            agent.direction = _agent_initial_direction
+            agent_idx += 1
+
+        return prediction_dict
+
+    def _new_position(self, position, movement):
+        """
+        Utility function that converts a compass movement over a 2D grid to new positions (r, c).
+        """
+        if movement == 0:  # NORTH
+            return (position[0] - 1, position[1])
+        elif movement == 1:  # EAST
+            return (position[0], position[1] + 1)
+        elif movement == 2:  # SOUTH
+            return (position[0] + 1, position[1])
+        elif movement == 3:  # WEST
+            return (position[0], position[1] - 1)
diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py
index cb7d26df75017b8a531ff6552eb6c1651c80a08f..5f5cea35ccac6f09f186bb84cfdb4122b11de4c4 100644
--- a/tests/test_env_prediction_builder.py
+++ b/tests/test_env_prediction_builder.py
@@ -74,7 +74,7 @@ def test_predictions():
     env.agents[0].direction = 0
     env.agents[0].target = (3., 0.)
 
-    predictions = env.obs_builder.predictor.get()
+    predictions = env.obs_builder.predictor.get(None)
     positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
     directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
     time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))