diff --git a/flatland/core/env.py b/flatland/core/env.py
index 32691f507f4cb5586f10b5645cc22ece718edc21..1bc5b6f3eba4ee4713bd3c8d6b88440006c215a5 100644
--- a/flatland/core/env.py
+++ b/flatland/core/env.py
@@ -84,27 +84,6 @@ class Environment:
         """
         raise NotImplementedError()
 
-    def predict(self):
-        """
-        Predictions step.
-
-        Returns predictions for the agents.
-        The returns are dicts mapping from agent_id strings to values.
-
-        Returns
-        -------
-        predictions : dict
-            New predictions for each ready agent.
-
-        """
-        raise NotImplementedError()
-
-    def render(self):
-        """
-        Perform rendering of the environment.
-        """
-        raise NotImplementedError()
-
     def get_agent_handles(self):
         """
         Returns a list of agents' handles to be used as keys in the step()
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index b30c2b1f5ddab079c9b6c41e35f03c69ed4162c3..53e7a068b73f9907217777251bce0fdd704603be 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -30,6 +30,27 @@ class ObservationBuilder:
         """
         raise NotImplementedError()
 
+    def get_many(self, handles=[]):
+        """
+        Called whenever an observation has to be computed for the `env' environment, for each agent with handle
+        in the `handles' list.
+
+        Parameters
+        -------
+        handles : list of handles (optional)
+            List with the handles of the agents for which to compute the observation vector.
+
+        Returns
+        -------
+        function
+            A dictionary of observation structures, specific to the corresponding environment, with handles from
+            `handles' as keys.
+        """
+        observations = {}
+        for h in handles:
+            observations[h] = self.get(h)
+        return observations
+
     def get(self, handle=0):
         """
         Called whenever an observation has to be computed for the `env' environment, possibly
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 676051d8338534c704ef98c90bee08a2836d4cfb..541f8ad592d1481afb8eb6da2eb7b887aacae419 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -17,7 +17,7 @@ class TreeObsForRailEnv(ObservationBuilder):
     network to simplify the representation of the state of the environment for each agent.
     """
 
-    def __init__(self, max_depth):
+    def __init__(self, max_depth, predictor=None):
         self.max_depth = max_depth
 
         # Compute the size of the returned observation vector
@@ -30,7 +30,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.observation_space = [size * self.observation_dim]
         self.location_has_agent = {}
         self.location_has_agent_direction = {}
-
+        self.predictor = predictor
         self.agents_previous_reset = None
 
     def reset(self):
@@ -167,6 +167,21 @@ class TreeObsForRailEnv(ObservationBuilder):
         elif movement == 3:  # WEST
             return (position[0], position[1] - 1)
 
+    def get_many(self, handles=[]):
+        """
+        Called whenever an observation has to be computed for the `env' environment, for each agent with handle
+        in the `handles' list.
+        """
+
+        self.predictions = []
+        if self.predictor:
+            for a in range(len(handles)):
+                self.predictions.append(self.predictor.get(a))
+        observations = {}
+        for h in handles:
+            observations[h] = self.get(h)
+        return observations
+
     def get(self, handle):
         """
         Computes the current observation for agent `handle' in env
@@ -207,6 +222,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                 (possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
             0 = no agent present other direction than myself
 
+        #8: possible conflict detected
+
 
         Missing/padding nodes are filled in with -inf (truncated).
         Missing values in present node are filled in with +inf (truncated).
@@ -241,7 +258,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
             if possible_transitions[branch_direction]:
                 new_cell = self._new_position(agent.position, branch_direction)
-
                 branch_observation, branch_visited = \
                     self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
                 observation = observation + branch_observation
@@ -524,6 +540,11 @@ class TreeObsForRailEnv(ObservationBuilder):
                 agent_data.extend(tmp_agent_data)
         return tree_data, distance_data, agent_data
 
+    def _set_env(self, env):
+        self.env = env
+        if self.predictor:
+            self.predictor._set_env(self.env)
+
 
 class GlobalObsForRailEnv(ObservationBuilder):
     """
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 795aabba67bf0deaf3d73c69a74788b9527abb58..c22e1c5120b54a170f9c59bb54c7666ca910f086 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -58,7 +58,6 @@ class RailEnv(Environment):
                  rail_generator=random_rail_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
-                 prediction_builder_object=None
                  ):
         """
         Environment init.
@@ -99,10 +98,6 @@ class RailEnv(Environment):
         self.obs_builder = obs_builder_object
         self.obs_builder._set_env(self)
 
-        self.prediction_builder = prediction_builder_object
-        if self.prediction_builder:
-            self.prediction_builder._set_env(self)
-
         self.action_space = [1]
         self.observation_space = self.obs_builder.observation_space  # updated on resets?
 
@@ -297,11 +292,6 @@ class RailEnv(Environment):
             np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
         return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
 
-    def predict(self):
-        if not self.prediction_builder:
-            return {}
-        return self.prediction_builder.get()
-
     def check_action(self, agent, action):
         transition_isValid = None
         possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
@@ -330,21 +320,9 @@ class RailEnv(Environment):
         return new_direction, transition_isValid
 
     def _get_observations(self):
-        self.obs_dict = {}
-        self.debug_obs_dict = {}
-        for iAgent in range(self.get_num_agents()):
-            self.obs_dict[iAgent] = self.obs_builder.get(iAgent)
+        self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
         return self.obs_dict
 
-    def _get_predictions(self):
-        if not self.prediction_builder:
-            return {}
-        return {}
-
-    def render(self):
-        # TODO:
-        pass
-
     def get_full_state_msg(self):
         grid_data = self.rail.grid.tolist()
         agent_static_data = [agent.to_list() for agent in self.agents_static]
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index be804c0348cf0fca9cc0c089c64693f7a617064b..8e934904f86d2da0dae6d59038a2e0b499415271 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -383,7 +383,9 @@ class PILSVG(PILGL):
             (0, 3): "Zug_2_Weiche_#0091ea.svg"
         }
 
-        # "paint" color of the train images we load
+        # "paint" color of the train images we load - this is the color we will change.
+        # a3BaseColor = self.rgb_s2i("0091ea") \#  noqa: E800
+        # temporary workaround for trains / agents renamed with different colour:
         a3BaseColor = self.rgb_s2i("d50000")
 
         self.dPilZug = {}
diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py
index 35a6a27b970ce54e1cabd3cf8c80d30a34800a25..a1c951d31f7121a8445f4309c31ce58653c7e463 100644
--- a/tests/test_env_prediction_builder.py
+++ b/tests/test_env_prediction_builder.py
@@ -5,7 +5,7 @@ import numpy as np
 
 from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
 from flatland.envs.generators import rail_from_GridTransitionMap_generator
-from flatland.envs.observations import GlobalObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import DummyPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 
@@ -64,8 +64,7 @@ def test_predictions():
                   height=rail_map.shape[0],
                   rail_generator=rail_from_GridTransitionMap_generator(rail),
                   number_of_agents=1,
-                  obs_builder_object=GlobalObsForRailEnv(),
-                  prediction_builder_object=DummyPredictorForRailEnv(max_depth=20)
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=20)),
                   )
 
     env.reset()
@@ -74,7 +73,7 @@ def test_predictions():
     env.agents[0].position = (5, 6)
     env.agents[0].direction = 0
 
-    predictions = env.predict()
+    predictions = env.obs_builder.predictor.get()
     positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
     directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
     time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))