diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py
index 3c35b30a3be4e0c1563b09c1c86359295226079e..723bb1102092c7d48bd938bbd60d0c5213ffecf6 100644
--- a/examples/custom_observation_example.py
+++ b/examples/custom_observation_example.py
@@ -1,10 +1,13 @@
 import random
 import time
+
 import numpy as np
 
-from flatland.envs.observations import TreeObsForRailEnv
 from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.grid.grid_utils import coordinate_to_position
 from flatland.envs.generators import random_rail_generator, complex_rail_generator
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
 
@@ -86,20 +89,11 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
 
 env = RailEnv(width=7,
               height=7,
-              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
-              number_of_agents=2,
+              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0),
+              number_of_agents=1,
               obs_builder_object=SingleAgentNavigationObs())
 
-obs, all_rewards, done, _ = env.step({0: 0, 1: 1})
-for i in range(env.get_num_agents()):
-    print(obs[i])
-
-env = RailEnv(width=50,
-            height=50,
-            rail_generator=random_rail_generator(),
-            number_of_agents=1,
-            obs_builder_object=SingleAgentNavigationObs())
-obs, all_rewards, done, _ = env.step({0: 0})
+obs = env.reset()
 env_renderer = RenderTool(env, gl="PILSVG")
 env_renderer.render_env(show=True, frames=True, show_observations=True)
 for step in range(100):
@@ -108,5 +102,119 @@ for step in range(100):
     print("Rewards: ", all_rewards, "  [done=", done, "]")
     env_renderer.render_env(show=True, frames=True, show_observations=True)
     time.sleep(0.1)
+    if done["__all__"]:
+        break
+env_renderer.close_window()
+
+
+class ObservePredictions(TreeObsForRailEnv):
+    """
+    We use the provided ShortestPathPredictor to illustrate the usage of predictors in your custom observation.
+
+    We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
+    the minimum distances from each grid node to each agent's target.
 
+    This is necessary so that we can pass the distance map to the ShortestPathPredictor
 
+    Here we also want to highlight how you can visualize your observation
+    """
+
+    def __init__(self, predictor):
+        super().__init__(max_depth=0)
+        self.observation_space = [10]
+        self.predictor = predictor
+
+    def reset(self):
+        # Recompute the distance map, if the environment has changed.
+        super().reset()
+
+    def get_many(self, handles=None):
+        '''
+        Because we do not want to call the predictor seperately for every agent we implement the get_many function
+        Here we can call the predictor just ones for all the agents and use the predictions to generate our observations
+        :param handles:
+        :return:
+        '''
+
+        self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
+
+        self.predicted_pos = {}
+        for t in range(len(self.predictions[0])):
+            pos_list = []
+            for a in handles:
+                pos_list.append(self.predictions[a][t][1:3])
+            # We transform (x,y) coodrinates to a single integer number for simpler comparison
+            self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
+        observations = {}
+
+        # Collect all the different observation for all the agents
+        for h in handles:
+            observations[h] = self.get(h)
+        return observations
+
+    def get(self, handle):
+        '''
+        Lets write a simple observation which just indicates whether or not the own predicted path
+        overlaps with other predicted paths at any time. This is useless for the task of navigation but might
+        help when looking for conflicts. A more complex implementation can be found in the TreeObsForRailEnv class
+
+        Each agent recieves an observation of length 10, where each element represents a prediction step and its value
+        is:
+         - 0 if no overlap is happening
+         - 1 where n i the number of other paths crossing the predicted cell
+
+        :param handle: handeled as an index of an agent
+        :return: Observation of handle
+        '''
+
+        observation = np.zeros(10)
+
+        # We are going to track what cells where considered while building the obervation and make them accesible
+        # For rendering
+
+        visited = set()
+        for _idx in range(10):
+            # Check if any of the other prediction overlap with agents own predictions
+            x_coord = self.predictions[handle][_idx][1]
+            y_coord = self.predictions[handle][_idx][2]
+
+            # We add every observed cell to the observation rendering
+            visited.add((x_coord, y_coord))
+            if self.predicted_pos[_idx][handle] in np.delete(self.predicted_pos[_idx], handle, 0):
+                # We detect if another agent is predicting to pass through the same cell at the same predicted time
+                observation[handle] = 1
+
+        # This variable will be access by the renderer to visualize the observation
+        self.env.dev_obs_dict[handle] = visited
+
+        return observation
+
+
+# Initiate the Predictor
+CustomPredictor = ShortestPathPredictorForRailEnv(10)
+
+# Pass the Predictor to the observation builder
+CustomObsBuilder = ObservePredictions(CustomPredictor)
+
+# Initiate Environment
+env = RailEnv(width=10,
+              height=10,
+              rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+              number_of_agents=3,
+              obs_builder_object=CustomObsBuilder)
+
+obs = env.reset()
+env_renderer = RenderTool(env, gl="PILSVG")
+
+# We render the initial step and show the obsered cells as colored boxes
+env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)
+
+action_dict = {}
+for step in range(100):
+    for a in range(env.get_num_agents()):
+        action = np.random.randint(0, 5)
+        action_dict[a] = action
+    obs, all_rewards, done, _ = env.step(action_dict)
+    print("Rewards: ", all_rewards, "  [done=", done, "]")
+    env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)
+    time.sleep(0.5)
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 1d825ff16554ef859339f0a60f1dbbce28dc62ef..24c57277572c551d2ff8bcfa1739d7ac863a553b 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -86,6 +86,9 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
     The prediction acts as if no other agent is in the environment and always takes the forward action.
     """
 
+    def __init__(self, max_depth):
+        self.max_depth = max_depth
+
     def get(self, custom_args=None, handle=None):
         """
         Called whenever get_many in the observation build is called.