diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index f3c0a7c3c06ff27ecd956e1734a6fa33fd1236b3..a3d88d773db9edaa0777e2aee94593a0392a956c 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -17,7 +17,7 @@ class TreeObsForRailEnv(ObservationBuilder):
     network to simplify the representation of the state of the environment for each agent.
     """
 
-    def __init__(self, max_depth):
+    def __init__(self, max_depth, predictor=None):
         self.max_depth = max_depth
 
         # Compute the size of the returned observation vector
@@ -30,6 +30,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.observation_space = [size * self.observation_dim]
         self.location_has_agent = {}
         self.location_has_agent_direction = {}
+        self.predictor = predictor
 
         self.agents_previous_reset = None
 
@@ -174,7 +175,8 @@ class TreeObsForRailEnv(ObservationBuilder):
         """
 
         # TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object.
-
+        if self.predictor:
+            print(self.predictor.get(0))
         observations = {}
         for h in handles:
             observations[h] = self.get(h)
@@ -537,6 +539,11 @@ class TreeObsForRailEnv(ObservationBuilder):
                 agent_data.extend(tmp_agent_data)
         return tree_data, distance_data, agent_data
 
+    def _set_env(self, env):
+        self.env = env
+        if self.predictor:
+            self.predictor._set_env(self.env)
+
 
 class GlobalObsForRailEnv(ObservationBuilder):
     """
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 44ed3f770600ba1aae1745926109deb1fa7398ef..5d20a5d9f38230f353b0a9616c49ede333206c49 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -58,7 +58,6 @@ class RailEnv(Environment):
                  rail_generator=random_rail_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
-                 prediction_builder_object=None
                  ):
         """
         Environment init.
@@ -99,10 +98,6 @@ class RailEnv(Environment):
         self.obs_builder = obs_builder_object
         self.obs_builder._set_env(self)
 
-        self.prediction_builder = prediction_builder_object
-        if self.prediction_builder:
-            self.prediction_builder._set_env(self)
-
         self.action_space = [1]
         self.observation_space = self.obs_builder.observation_space  # updated on resets?
 
@@ -297,10 +292,6 @@ class RailEnv(Environment):
             np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
         return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
 
-    def predict(self):
-        if not self.prediction_builder:
-            return {}
-        return self.prediction_builder.get()
 
     def check_action(self, agent, action):
         transition_isValid = None
@@ -333,10 +324,6 @@ class RailEnv(Environment):
         self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
         return self.obs_dict
 
-    def _get_predictions(self):
-        if not self.prediction_builder:
-            return {}
-        return {}
 
     def get_full_state_msg(self):
         grid_data = self.rail.grid.tolist()