diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index e6b7a4f0b4ad84e6b07aa9395efc0aae2c5e5451..5c4ed86bdb4517d9501a887a6be6f256a7297e7e 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -468,7 +468,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                         prompt=prompt_[children],
                                         current_depth=current_depth + 1)
 
-    def split_tree(self, tree, num_features_per_node=5, current_depth=0):
+    def split_tree(self, tree, num_features_per_node=7, current_depth=0):
         """
 
         :param tree:
@@ -490,17 +490,19 @@ class TreeObsForRailEnv(ObservationBuilder):
             pow4 *= 4
         child_size = (len(tree) - num_features_per_node) // 4
         tree_data = tree[0:num_features_per_node - 1].tolist()
-        distance_data = [tree[num_features_per_node - 1]]
+        distance_data = [tree[num_features_per_node - 3]]
+        agent_data = tree[-2:]
         for children in range(4):
             child_tree = tree[(num_features_per_node + children * child_size):
                               (num_features_per_node + (children + 1) * child_size)]
-            tmp_tree_data, tmp_distance_data = self.split_tree(child_tree,
+            tmp_tree_data, tmp_distance_data, tmp_agent_data = self.split_tree(child_tree,
                                                                num_features_per_node,
                                                                current_depth=current_depth + 1)
             if len(tmp_tree_data) > 0:
                 tree_data.extend(tmp_tree_data)
                 distance_data.extend(tmp_distance_data)
-        return tree_data, distance_data
+                agent_data.extrend(tmp_agent_data)
+        return tree_data, distance_data, agent_data
 
 
 class GlobalObsForRailEnv(ObservationBuilder):