diff --git a/flatland/core/env.py b/flatland/core/env.py
index 3618d965a39b5a71fd1cf24aa81f2f876d5c6365..1bc5b6f3eba4ee4713bd3c8d6b88440006c215a5 100644
--- a/flatland/core/env.py
+++ b/flatland/core/env.py
@@ -84,21 +84,6 @@ class Environment:
         """
         raise NotImplementedError()
 
-    def predict(self):
-        """
-        Predictions step.
-
-        Returns predictions for the agents.
-        The returns are dicts mapping from agent_id strings to values.
-
-        Returns
-        -------
-        predictions : dict
-            New predictions for each ready agent.
-
-        """
-        raise NotImplementedError()
-
     def get_agent_handles(self):
         """
         Returns a list of agents' handles to be used as keys in the step()
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index a3d88d773db9edaa0777e2aee94593a0392a956c..541f8ad592d1481afb8eb6da2eb7b887aacae419 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -31,7 +31,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.location_has_agent = {}
         self.location_has_agent_direction = {}
         self.predictor = predictor
-
         self.agents_previous_reset = None
 
     def reset(self):
@@ -174,9 +173,10 @@ class TreeObsForRailEnv(ObservationBuilder):
         in the `handles' list.
         """
 
-        # TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object.
+        self.predictions = []
         if self.predictor:
-            print(self.predictor.get(0))
+            for a in range(len(handles)):
+                self.predictions.append(self.predictor.get(a))
         observations = {}
         for h in handles:
             observations[h] = self.get(h)
@@ -222,6 +222,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                 (possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
             0 = no agent present other direction than myself
 
+        #8: possible conflict detected
+
 
         Missing/padding nodes are filled in with -inf (truncated).
         Missing values in present node are filled in with +inf (truncated).
@@ -256,7 +258,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
             if possible_transitions[branch_direction]:
                 new_cell = self._new_position(agent.position, branch_direction)
-
                 branch_observation, branch_visited = \
                     self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
                 observation = observation + branch_observation
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 5d20a5d9f38230f353b0a9616c49ede333206c49..7773f86c1407153f649c972398c9a58067a38947 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -292,7 +292,6 @@ class RailEnv(Environment):
             np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
         return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
 
-
     def check_action(self, agent, action):
         transition_isValid = None
         possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
@@ -324,7 +323,6 @@ class RailEnv(Environment):
         self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
         return self.obs_dict
 
-
     def get_full_state_msg(self):
         grid_data = self.rail.grid.tolist()
         agent_static_data = [agent.to_list() for agent in self.agents_static]
diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py
index 5b0e830bbe14b081c53ce8a31ef2f9db270b62e3..ae910a369da849a86d39721d83b78ba777086d00 100644
--- a/tests/test_env_prediction_builder.py
+++ b/tests/test_env_prediction_builder.py
@@ -5,7 +5,7 @@ import numpy as np
 
 from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
 from flatland.envs.generators import rail_from_GridTransitionMap_generator
-from flatland.envs.observations import GlobalObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import DummyPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 
@@ -64,8 +64,7 @@ def test_predictions():
                   height=rail_map.shape[0],
                   rail_generator=rail_from_GridTransitionMap_generator(rail),
                   number_of_agents=1,
-                  obs_builder_object=GlobalObsForRailEnv(),
-                  prediction_builder_object=DummyPredictorForRailEnv(max_depth=10)
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=20)),
                   )
 
     env.reset()
@@ -75,7 +74,7 @@ def test_predictions():
     env.agents[0].direction = 0
     env.agents[0].target = (3., 0.)
 
-    predictions = env.predict()
+    predictions = env.obs_builder.predictor.get()
     positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
     directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
     time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))