diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 41d5603dbf0f9c3ea38474f8e58bc78353a6f1f0..d0d4ba4806e953f1185121380107c1f9a4e5c357 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -548,9 +548,10 @@ class TreeObsForRailEnv(ObservationBuilder):
             depth += 1
             pow4 *= 4
         child_size = (len(tree) - num_features_per_node) // 4
-        tree_data = tree[0:4].tolist()
+        tree_data = tree[:4].tolist()
         distance_data = [tree[4]]
-        agent_data = tree[num_features_per_node - 3:num_features_per_node].tolist()
+        agent_data = tree[5:num_features_per_node].tolist()
+        print(agent_data)
         for children in range(4):
             child_tree = tree[(num_features_per_node + children * child_size):
                               (num_features_per_node + (children + 1) * child_size)]