diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 483ae3053a7884f84668c653d6672ce33982b8b7..30c0fabeae0d7f9dfd49ba05983ad90f457cd5f2 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -16,17 +16,17 @@ from flatland.utils.ordered_set import OrderedSet
 
 class TreeObsForRailEnv(ObservationBuilder):
 
-    Node = collections.namedtuple('Node', 'dist_1 '
-                                          'dist_2 '
-                                          'dist_3 '
-                                          'dist_4 '
-                                          'dist_5 '
-                                          'dist_6 '
-                                          'dist_7 '
-                                          'num_agents_8 '
-                                          'num_agents_9 '
-                                          'num_agents_10 '
-                                          'speed_11 '
+    Node = collections.namedtuple('Node', 'dist_own_target_encountered '
+                                          'dist_other_target_encountered '
+                                          'dist_other_agent_encountered '
+                                          'dist_potential_conflict '
+                                          'dist_unusable_switch '
+                                          'dist_to_next_branch '
+                                          'dist_min_to_target '
+                                          'num_agents_same_direction '
+                                          'num_agents_opposite_direction '
+                                          'num_agents_malfunctioning '
+                                          'speed_min_fractional '
                                           'childs')
     """
     TreeObsForRailEnv object.
@@ -53,7 +53,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.location_has_agent_direction = {}
         self.predictor = predictor
         self.location_has_target = None
-        self.tree_explored_actions = [1, 2, 3, 0]
         self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
 
     def reset(self):
@@ -181,12 +180,15 @@ class TreeObsForRailEnv(ObservationBuilder):
         # Here information about the agent itself is stored
         distance_map = self.env.distance_map.get()
 
-        root_node_observation = TreeObsForRailEnv.Node(dist_1=0, dist_2=0, dist_3=0, dist_4=0, dist_5=0, dist_6=0,
-                                         dist_7=distance_map[(handle, *agent.position, agent.direction)],
-                                         num_agents_8=0, num_agents_9=0,
-                                         num_agents_10=agent.malfunction_data['malfunction'],
-                                         speed_11=agent.speed_data['speed'],
-                                         childs={})
+        root_node_observation = TreeObsForRailEnv.Node(dist_own_target_encountered=0, dist_other_target_encountered=0,
+                                                       dist_other_agent_encountered=0, dist_potential_conflict=0,
+                                                       dist_unusable_switch=0, dist_to_next_branch=0,
+                                                       dist_min_to_target=distance_map[(handle, *agent.position,
+                                                                                        agent.direction)],
+                                                       num_agents_same_direction=0, num_agents_opposite_direction=0,
+                                                       num_agents_malfunctioning=agent.malfunction_data['malfunction'],
+                                                       speed_min_fractional=agent.speed_data['speed'],
+                                                       childs={})
 
         visited = OrderedSet()
 
@@ -215,15 +217,6 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         return root_node_observation
 
-    def _num_cells_to_fill_in(self, remaining_depth):
-        """Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim."""
-        num_observations = 0
-        pow4 = 1
-        for i in range(remaining_depth):
-            num_observations += pow4
-            pow4 *= 4
-        return num_observations * self.observation_dim
-
     def _explore_branch(self, handle, position, direction, tot_dist, depth):
         """
         Utility function to compute tree-based observations.
@@ -398,37 +391,28 @@ class TreeObsForRailEnv(ObservationBuilder):
         # Modify here to append new / different features for each visited cell!
 
         if last_is_target:
-            node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
-                                        dist_3=other_agent_encountered, dist_4=potential_conflict,
-                                        dist_5=unusable_switch, dist_6=tot_dist,
-                                        dist_7=0,
-                                        num_agents_8=other_agent_same_direction,
-                                        num_agents_9=other_agent_opposite_direction,
-                                        num_agents_10=malfunctioning_agent,
-                                        speed_11=min_fractional_speed,
-                                        childs={})
-
+            dist_to_next_branch = tot_dist,
+            dist_min_to_target = 0,
         elif last_is_terminal:
-            node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
-                                        dist_3=other_agent_encountered, dist_4=potential_conflict,
-                                        dist_5=unusable_switch, dist_6=np.inf,
-                                        dist_7=self.env.distance_map.get()[handle, position[0], position[1], direction],
-                                        num_agents_8=other_agent_same_direction,
-                                        num_agents_9=other_agent_opposite_direction,
-                                        num_agents_10=malfunctioning_agent,
-                                        speed_11=min_fractional_speed,
-                                        childs={})
-
+            dist_to_next_branch = np.inf,
+            dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction],
         else:
-            node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
-                                        dist_3=other_agent_encountered, dist_4=potential_conflict,
-                                        dist_5=unusable_switch, dist_6=tot_dist,
-                                        dist_7=self.env.distance_map.get()[handle, position[0], position[1], direction],
-                                        num_agents_8=other_agent_same_direction,
-                                        num_agents_9=other_agent_opposite_direction,
-                                        num_agents_10=malfunctioning_agent,
-                                        speed_11=min_fractional_speed,
-                                        childs={})
+            dist_to_next_branch = tot_dist,
+            dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction],
+
+        node = TreeObsForRailEnv.Node(dist_own_target_encountered=own_target_encountered,
+                                      dist_other_target_encountered=other_target_encountered,
+                                      dist_other_agent_encountered=other_agent_encountered,
+                                      dist_potential_conflict=potential_conflict,
+                                      dist_unusable_switch=unusable_switch,
+                                      dist_to_next_branch=dist_to_next_branch,
+                                      dist_min_to_target=dist_min_to_target,
+                                      num_agents_same_direction=other_agent_same_direction,
+                                      num_agents_opposite_direction=other_agent_opposite_direction,
+                                      num_agents_malfunctioning=malfunctioning_agent,
+                                      speed_min_fractional=min_fractional_speed,
+                                      childs={})
+
         # #############################
         # #############################
         # Start from the current orientation, and see which transitions are available;
@@ -475,10 +459,13 @@ class TreeObsForRailEnv(ObservationBuilder):
         for direction in self.tree_explorted_actions_char:
             self.print_subtree(tree.childs[direction], direction, "\t")
 
-    def print_node_features(self, node: Node, label, indent):
-        print(indent, "Direction ", label, ": ", node.dist_1, ", ", node.dist_2, ", ", node.dist_3, ", ", node.dist_4,
-              ", ", node.dist_5, ", ", node.dist_6, ", ", node.dist_7, ", ", node.num_agents_8, ", ", node.num_agents_9,
-              ", ", node.num_agents_10, ", ", node.speed_11)
+    @staticmethod
+    def print_node_features(node: Node, label, indent):
+        print(indent, "Direction ", label, ": ", node.dist_own_target_encountered, ", ",
+              node.dist_other_target_encountered, ", ", node.dist_other_agent_encountered, ", ",
+              node.dist_potential_conflict, ", ", node.dist_unusable_switch, ", ", node.dist_to_next_branch, ", ",
+              node.dist_min_to_target, ", ", node.num_agents_same_direction, ", ", node.num_agents_opposite_direction,
+              ", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional)
 
     def print_subtree(self, node, label, indent):
         if node == -np.inf or not node:
@@ -493,37 +480,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         for direction in self.tree_explorted_actions_char:
             self.print_subtree(node.childs[direction], direction, indent + "\t")
 
-
-    def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True):
-        """
-        Utility function to pretty-print tree observations returned by this object.
-        """
-        if len(tree) < self.observation_dim:
-            return
-
-        depth = 0
-        tmp = len(tree) / self.observation_dim - 1
-        pow4 = 4
-        while tmp > 0:
-            tmp -= pow4
-            depth += 1
-            pow4 *= 4
-
-        unfolded = {}
-        unfolded[''] = tree[0:self.observation_dim]
-        child_size = (len(tree) - self.observation_dim) // 4
-        for child in range(4):
-            child_tree = tree[(self.observation_dim + child * child_size):
-                              (self.observation_dim + (child + 1) * child_size)]
-            observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1)
-            if observation_tree is not None:
-                if actions_for_display:
-                    label = self.tree_explorted_actions_char[child]
-                else:
-                    label = self.tree_explored_actions[child]
-                unfolded[label] = observation_tree
-        return unfolded
-
     def set_env(self, env: Environment):
         super().set_env(env)
         if self.predictor:
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index 7c5e685f3940e0e81859d02dac3752608b5771b7..9ef122e6aa8325d3ec70307332cc7235090270de 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -277,17 +277,17 @@ def test_shortest_path_predictor_conflicts(rendering=False):
 
 
 def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''):
-    assert (tree.num_agents_9 > 0) == (() in expected_conflicts), "{}[]".format(prompt)
+    assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt)
     for a_1 in obs_builder.tree_explorted_actions_char:
         if tree.childs[a_1] == -np.inf:
             assert False == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
             continue
         else:
-            conflict = tree.childs[a_1].num_agents_9
+            conflict = tree.childs[a_1].num_agents_opposite_direction
             assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
         for a_2 in obs_builder.tree_explorted_actions_char:
             if tree.childs[a_1].childs[a_2] == -np.inf:
                 assert False == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
             else:
-                conflict = tree.childs[a_1].childs[a_2].num_agents_9
+                conflict = tree.childs[a_1].childs[a_2].num_agents_opposite_direction
                 assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)