diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index de9ee2a45ebdeabec5202be4f12593a82b4e20e4..483ae3053a7884f84668c653d6672ce33982b8b7 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -1,8 +1,8 @@
 """
 Collection of environment-specific ObservationBuilder.
 """
-import pprint
-from typing import Optional, List, Dict, T, Tuple
+import collections
+from typing import Optional, List, Dict, Tuple
 
 import numpy as np
 
@@ -15,6 +15,19 @@ from flatland.utils.ordered_set import OrderedSet
 
 
 class TreeObsForRailEnv(ObservationBuilder):
+
+    Node = collections.namedtuple('Node', 'dist_1 '
+                                          'dist_2 '
+                                          'dist_3 '
+                                          'dist_4 '
+                                          'dist_5 '
+                                          'dist_6 '
+                                          'dist_7 '
+                                          'num_agents_8 '
+                                          'num_agents_9 '
+                                          'num_agents_10 '
+                                          'speed_11 '
+                                          'childs')
     """
     TreeObsForRailEnv object.
 
@@ -165,11 +178,15 @@ class TreeObsForRailEnv(ObservationBuilder):
         possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
         num_transitions = np.count_nonzero(possible_transitions)
 
-        # Root node - current position
         # Here information about the agent itself is stored
         distance_map = self.env.distance_map.get()
-        observation = [0, 0, 0, 0, 0, 0, distance_map[(handle, *agent.position, agent.direction)], 0, 0,
-                       agent.malfunction_data['malfunction'], agent.speed_data['speed']]
+
+        root_node_observation = TreeObsForRailEnv.Node(dist_1=0, dist_2=0, dist_3=0, dist_4=0, dist_5=0, dist_6=0,
+                                         dist_7=distance_map[(handle, *agent.position, agent.direction)],
+                                         num_agents_8=0, num_agents_9=0,
+                                         num_agents_10=agent.malfunction_data['malfunction'],
+                                         speed_11=agent.speed_data['speed'],
+                                         childs={})
 
         visited = OrderedSet()
 
@@ -181,19 +198,22 @@ class TreeObsForRailEnv(ObservationBuilder):
         if num_transitions == 1:
             orientation = np.argmax(possible_transitions)
 
-        for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
+        for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):
+
             if possible_transitions[branch_direction]:
                 new_cell = get_new_position(agent.position, branch_direction)
+
                 branch_observation, branch_visited = \
                     self._explore_branch(handle, new_cell, branch_direction, 1, 1)
-                observation = observation + branch_observation
+                root_node_observation.childs[self.tree_explorted_actions_char[i]] = branch_observation
+
                 visited |= branch_visited
             else:
                 # add cells filled with infinity if no transition is possible
-                observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth)
+                root_node_observation.childs[self.tree_explorted_actions_char[i]] = -np.inf
         self.env.dev_obs_dict[handle] = visited
 
-        return observation
+        return root_node_observation
 
     def _num_cells_to_fill_in(self, remaining_depth):
         """Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim."""
@@ -378,53 +398,44 @@ class TreeObsForRailEnv(ObservationBuilder):
         # Modify here to append new / different features for each visited cell!
 
         if last_is_target:
-            observation = [own_target_encountered,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           potential_conflict,
-                           unusable_switch,
-                           tot_dist,
-                           0,
-                           other_agent_same_direction,
-                           other_agent_opposite_direction,
-                           malfunctioning_agent,
-                           min_fractional_speed
-                           ]
+            node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
+                                        dist_3=other_agent_encountered, dist_4=potential_conflict,
+                                        dist_5=unusable_switch, dist_6=tot_dist,
+                                        dist_7=0,
+                                        num_agents_8=other_agent_same_direction,
+                                        num_agents_9=other_agent_opposite_direction,
+                                        num_agents_10=malfunctioning_agent,
+                                        speed_11=min_fractional_speed,
+                                        childs={})
 
         elif last_is_terminal:
-            observation = [own_target_encountered,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           potential_conflict,
-                           unusable_switch,
-                           np.inf,
-                           self.env.distance_map.get()[handle, position[0], position[1], direction],
-                           other_agent_same_direction,
-                           other_agent_opposite_direction,
-                           malfunctioning_agent,
-                           min_fractional_speed
-                           ]
+            node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
+                                        dist_3=other_agent_encountered, dist_4=potential_conflict,
+                                        dist_5=unusable_switch, dist_6=np.inf,
+                                        dist_7=self.env.distance_map.get()[handle, position[0], position[1], direction],
+                                        num_agents_8=other_agent_same_direction,
+                                        num_agents_9=other_agent_opposite_direction,
+                                        num_agents_10=malfunctioning_agent,
+                                        speed_11=min_fractional_speed,
+                                        childs={})
 
         else:
-            observation = [own_target_encountered,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           potential_conflict,
-                           unusable_switch,
-                           tot_dist,
-                           self.env.distance_map.get()[handle, position[0], position[1], direction],
-                           other_agent_same_direction,
-                           other_agent_opposite_direction,
-                           malfunctioning_agent,
-                           min_fractional_speed
-                           ]
+            node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
+                                        dist_3=other_agent_encountered, dist_4=potential_conflict,
+                                        dist_5=unusable_switch, dist_6=tot_dist,
+                                        dist_7=self.env.distance_map.get()[handle, position[0], position[1], direction],
+                                        num_agents_8=other_agent_same_direction,
+                                        num_agents_9=other_agent_opposite_direction,
+                                        num_agents_10=malfunctioning_agent,
+                                        speed_11=min_fractional_speed,
+                                        childs={})
         # #############################
         # #############################
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         # Get the possible transitions
         possible_transitions = self.env.rail.get_transitions(*position, direction)
-        for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
+        for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]):
             if last_is_dead_end and self.env.rail.get_transition((*position, direction),
                                                                  (branch_direction + 2) % 4):
                 # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
@@ -435,7 +446,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                                           (branch_direction + 2) % 4,
                                                                           tot_dist + 1,
                                                                           depth + 1)
-                observation = observation + branch_observation
+                node.childs[self.tree_explorted_actions_char[i]] = branch_observation
                 if len(branch_visited) != 0:
                     visited |= branch_visited
             elif last_is_switch and possible_transitions[branch_direction]:
@@ -445,21 +456,43 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                                           branch_direction,
                                                                           tot_dist + 1,
                                                                           depth + 1)
-                observation = observation + branch_observation
+                node.childs[self.tree_explorted_actions_char[i]] = branch_observation
                 if len(branch_visited) != 0:
                     visited |= branch_visited
             else:
                 # no exploring possible, add just cells with infinity
-                observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth)
+                node.childs[self.tree_explorted_actions_char[i]] = -np.inf
 
-        return observation, visited
+        if depth == self.max_depth:
+            node.childs.clear()
+        return node, visited
 
-    def util_print_obs_subtree(self, tree):
+    def util_print_obs_subtree(self, tree: Node):
         """
-        Utility function to pretty-print tree observations returned by this object.
+        Utility function to print tree observations returned by this object.
         """
-        pp = pprint.PrettyPrinter(indent=4)
-        pp.pprint(self.unfold_observation_tree(tree))
+        self.print_node_features(tree, "root", "")
+        for direction in self.tree_explorted_actions_char:
+            self.print_subtree(tree.childs[direction], direction, "\t")
+
+    def print_node_features(self, node: Node, label, indent):
+        print(indent, "Direction ", label, ": ", node.dist_1, ", ", node.dist_2, ", ", node.dist_3, ", ", node.dist_4,
+              ", ", node.dist_5, ", ", node.dist_6, ", ", node.dist_7, ", ", node.num_agents_8, ", ", node.num_agents_9,
+              ", ", node.num_agents_10, ", ", node.speed_11)
+
+    def print_subtree(self, node, label, indent):
+        if node == -np.inf or not node:
+            print(indent, "Direction ", label, ": -np.inf")
+            return
+
+        self.print_node_features(node, label, indent)
+
+        if not node.childs:
+            return
+
+        for direction in self.tree_explorted_actions_char:
+            self.print_subtree(node.childs[direction], direction, indent + "\t")
+
 
     def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True):
         """
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index c31494673e63a17dc07eb6d89eeb581c640b1e13..7c5e685f3940e0e81859d02dac3752608b5771b7 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -264,9 +264,10 @@ def test_shortest_path_predictor_conflicts(rendering=False):
     # get the trees to test
     obs_builder: TreeObsForRailEnv = env.obs_builder
     pp = pprint.PrettyPrinter(indent=4)
-    tree_0 = obs_builder.unfold_observation_tree(observations[0])
-    tree_1 = obs_builder.unfold_observation_tree(observations[1])
-    pp.pprint(tree_0)
+    tree_0 = observations[0]
+    tree_1 = observations[1]
+    env.obs_builder.util_print_obs_subtree(tree_0)
+    env.obs_builder.util_print_obs_subtree(tree_1)
 
     # check the expectations
     expected_conflicts_0 = [('F', 'R')]
@@ -275,11 +276,18 @@ def test_shortest_path_predictor_conflicts(rendering=False):
     _check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ")
 
 
-def _check_expected_conflicts(expected_conflicts, obs_builder, tree_0, prompt=''):
-    assert (tree_0[''][8] > 0) == (() in expected_conflicts), "{}[]".format(prompt)
+def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''):
+    assert (tree.num_agents_9 > 0) == (() in expected_conflicts), "{}[]".format(prompt)
     for a_1 in obs_builder.tree_explorted_actions_char:
-        conflict = tree_0[a_1][''][8]
-        assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
+        if tree.childs[a_1] == -np.inf:
+            assert False == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
+            continue
+        else:
+            conflict = tree.childs[a_1].num_agents_9
+            assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
         for a_2 in obs_builder.tree_explorted_actions_char:
-            conflict = tree_0[a_1][a_2][''][8]
-            assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
+            if tree.childs[a_1].childs[a_2] == -np.inf:
+                assert False == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
+            else:
+                conflict = tree.childs[a_1].childs[a_2].num_agents_9
+                assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)