diff --git a/examples/training_example.py b/examples/training_example.py
index 1342107767a599cb1440fd08f506903a5059492e..ad222cb0889a92023c726f68ce8ecff0d2dd6cee 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -1,6 +1,8 @@
 import numpy as np
 
 from flatland.envs.generators import complex_rail_generator
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import DummyPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 
 np.random.seed(1)
@@ -8,10 +10,14 @@ np.random.seed(1)
 # Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
 # Training on simple, small tasks is the best way to get familiar with the environment
 #
-env = RailEnv(width=15,
-              height=15,
-              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
-              number_of_agents=5)
+
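+# Use a depth-2 tree observation whose nodes carry conflict information from the predictor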
+TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv())
+env = RailEnv(width=20,
+              height=20,
+              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+              obs_builder_object=TreeObservation,
+              number_of_agents=2)
 
 
 # Import your own Agent or use RLlib to train agents on Flatland
@@ -56,6 +61,7 @@ n_trials = 5
 # Empty dictionary for all agent actions
 action_dict = dict()
 print("Starting Training...")
+
 for trials in range(1, n_trials + 1):
 
     # Reset environment and get initial observations for all agents
@@ -74,7 +80,9 @@ for trials in range(1, n_trials + 1):
         # Environment step which returns the observations for all agents, their corresponding
         # rewards and whether they are done
         next_obs, all_rewards, done, _ = env.step(action_dict)
-
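+        # Debug output: pretty-print agent 0's tree observation and its flat length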
+        TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
+        print(len(next_obs[0]))
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
             agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
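
For reference, the example assumes an `agent` object exposing `act(state)` and `step(memories)`. A minimal random-action stand-in along those lines (hypothetical, not part of this diff) could look like:

```python
import numpy as np


class RandomAgent:
    """Hypothetical stand-in for the trainable agent used in the example."""

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size

    def act(self, state):
        # Pick one of the discrete RailEnv actions uniformly at random
        return np.random.choice(self.action_size)

    def step(self, memories):
        # memories = (state, action, reward, next_state, done); a real agent
        # would push this into a replay buffer and learn from it
        pass
```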
diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py
index 35e6560818acb0d5d7b2bf56bfd8feb33f26c87a..c9595b7693497daa7db110b8fc8b4ae040d39cc9 100644
--- a/flatland/envs/env_utils.py
+++ b/flatland/envs/env_utils.py
@@ -89,10 +89,10 @@ def coordinate_to_position(width, coords):
     :param coords: list of (x, y) coordinate tuples to flatten
     :return: 1-D integer array of scalar positions, one per coordinate
     """
-    position = []
-    for t in coords:
-        position.append((t[1] * width + t[0]))
-    return np.asarray(position).flatten()
+    position = np.empty(len(coords), dtype=int)
+    for idx, t in enumerate(coords):
+        position[idx] = int(t[1] * width + t[0])  # second coordinate * width + first
+    return position
 
 
 class AStarNode():
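
As a quick sanity check of the rewritten helper (standalone copy, assuming only numpy), the scalar position is the second coordinate times the grid width plus the first:

```python
import numpy as np


def coordinate_to_position(width, coords):
    # Standalone copy of the helper above
    position = np.empty(len(coords), dtype=int)
    for idx, t in enumerate(coords):
        position[idx] = int(t[1] * width + t[0])
    return position


print(coordinate_to_position(20, [(3, 5), (0, 0), (19, 19)]))
# -> [103   0 399]
```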
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 7bc2f994f8ae05f86e569cdb64bae08f6dc877f2..e0f0c7f1b62bcd81ed2db5eb8f7bbe67bcee8bcf 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -27,7 +27,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         for i in range(self.max_depth + 1):
             size += pow4
             pow4 *= 4
-        self.observation_dim = 7
+        self.observation_dim = 8
         self.observation_space = [size * self.observation_dim]
         self.location_has_agent = {}
         self.location_has_agent_direction = {}
@@ -187,10 +187,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                     dir_list.append(self.predictions[a][t][3])
                 self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
                 self.predicted_dir.update({t: dir_list})
-
-            pred_pos = np.concatenate([[x[:, 1:3]] for x in list(self.predictions.values())], axis=0)
-            pred_pos = list(map(list, zip(*pred_pos)))
-
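+            # One prediction entry per timestep, so this is the prediction horizon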
+            self.max_prediction_depth = len(self.predicted_pos)
         observations = {}
         for h in handles:
             observations[h] = self.get(h)
@@ -256,7 +253,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
         num_transitions = np.count_nonzero(possible_transitions)
         # Root node - current position
-        observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
+        observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, 0]
 
         root_observation = observation[:]
         visited = set()
@@ -309,7 +306,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         other_target_encountered = np.inf
         other_agent_same_direction = 0
         other_agent_opposite_direction = 0
-
+        potential_conflict = 0
         num_steps = 1
         while exploring:
             # #############################
@@ -329,6 +326,10 @@ class TreeObsForRailEnv(ObservationBuilder):
                     other_agent_opposite_direction += 1
 
             # Register possible conflict
+            if self.predictor and num_steps < self.max_prediction_depth:
+                pos = coordinate_to_position(self.env.width, [position])[0]
+                if pos in np.delete(self.predicted_pos[num_steps], handle):
+                    potential_conflict = 1
 
             if position in self.location_has_target:
                 if num_steps < other_target_encountered:
@@ -430,7 +431,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                            root_observation[3] + num_steps,
                            0,
                            other_agent_same_direction,
-                           other_agent_opposite_direction
+                           other_agent_opposite_direction,
+                           potential_conflict
                            ]
 
         elif last_isTerminal:
@@ -440,7 +442,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                            np.inf,
                            np.inf,
                            other_agent_same_direction,
-                           other_agent_opposite_direction
+                           other_agent_opposite_direction,
+                           potential_conflict
                            ]
         else:
             observation = [0,
@@ -449,7 +452,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                            root_observation[3] + num_steps,
                            self.distance_map[handle, position[0], position[1], direction],
                            other_agent_same_direction,
-                           other_agent_opposite_direction
+                           other_agent_opposite_direction,
+                           potential_conflict
                            ]
         # #############################
         # #############################
@@ -493,7 +497,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         return observation, visited
 
-    def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
+    def util_print_obs_subtree(self, tree, num_features_per_node=8, prompt='', current_depth=0):
         """
         Utility function to pretty-print tree observations returned by this object.
         """
@@ -520,7 +524,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                         prompt=prompt_[children],
                                         current_depth=current_depth + 1)
 
-    def split_tree(self, tree, num_features_per_node=7, current_depth=0):
+    def split_tree(self, tree, num_features_per_node=8, current_depth=0):
         """
 
     :param tree: flat tree observation to split into feature groups
@@ -541,9 +545,10 @@ class TreeObsForRailEnv(ObservationBuilder):
             depth += 1
             pow4 *= 4
         child_size = (len(tree) - num_features_per_node) // 4
-        tree_data = tree[0:num_features_per_node - 3].tolist()
-        distance_data = [tree[num_features_per_node - 3]]
-        agent_data = tree[-2:].tolist()
+        tree_data = tree[0:4].tolist()  # the four tree/branch distance features
+        distance_data = [tree[4]]  # the distance-map entry
+        agent_data = tree[5:num_features_per_node].tolist()  # agent features, incl. potential_conflict
+
         for children in range(4):
             child_tree = tree[(num_features_per_node + children * child_size):
                               (num_features_per_node + (children + 1) * child_size)]
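
To illustrate the 8-feature node layout that `split_tree` now assumes, here is a small sketch (sample values invented) splitting a single node vector into the three groups used above:

```python
import numpy as np

# One tree node after this change: four tree features, the distance-map
# entry, then the agent features (same direction, opposite direction,
# and the new potential_conflict flag). Values are samples only.
node = np.array([0., 0., 0., 4., 17., 1., 0., 1.])

tree_data = node[0:4].tolist()
distance_data = [node[4]]
agent_data = node[5:8].tolist()

print(tree_data)      # [0.0, 0.0, 0.0, 4.0]
print(distance_data)  # [17.0]
print(agent_data)     # [1.0, 0.0, 1.0]
```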