diff --git a/docs/tutorials/01_gettingstarted.rst b/docs/tutorials/01_gettingstarted.rst index 9ca370a0e1ed6a698d3e37111c70b97ac0ad2aa8..a7a8e5514e03d2f9feb31408d433badc7f3767c2 100644 --- a/docs/tutorials/01_gettingstarted.rst +++ b/docs/tutorials/01_gettingstarted.rst @@ -160,7 +160,7 @@ Once we are set with the environment we can load our preferred agent from either .. code-block:: python - agent = RandomAgent(env.action_space, env.observation_space) + agent = RandomAgent(state_size, action_size) We start every trial by resetting the environment diff --git a/docs/tutorials/02_observationbuilder.rst b/docs/tutorials/02_observationbuilder.rst index fd5decae31139a883289155131d71fd6870e0a74..d1c287fedf880fd091a1b24921292010eb01359e 100644 --- a/docs/tutorials/02_observationbuilder.rst +++ b/docs/tutorials/02_observationbuilder.rst @@ -18,7 +18,7 @@ base class and must implement two methods, :code:`reset(self)` and :code:`get(se .. _`flatland.core.env_observation_builder.ObservationBuilder` : https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/core/env_observation_builder.py#L13 -Below is a simple example that returns observation vectors of size :code:`observation_space = 5` featuring only the ID (handle) of the agent whose +Below is a simple example that returns observation vectors of size 5 featuring only the ID (handle) of the agent whose observation vector is being computed: .. code-block:: python @@ -28,14 +28,12 @@ observation vector is being computed: Simplest observation builder. The object returns observation vectors with 5 identical components, all equal to the ID of the respective agent. """ - def __init__(self): - self.observation_space = [5] def reset(self): return def get(self, handle): - observation = handle * np.ones((self.observation_space[0],)) + observation = handle * np.ones(5) return observation We can pass an instance of our custom observation builder :code:`SimpleObs` to the :code:`RailEnv` creator as follows: @@ -85,7 +83,6 @@ Note that this simple strategy fails when multiple agents are present, as each a super().__init__(max_depth=0) # We set max_depth=0 in because we only need to look at the current # position of the agent to decide what direction is shortest. - self.observation_space = [3] def reset(self): # Recompute the distance map, if the environment has changed. 
@@ -189,7 +186,6 @@ In contrast to the previous examples we also implement the :code:`def get_many(s def __init__(self, predictor): super().__init__(max_depth=0) - self.observation_space = [10] self.predictor = predictor def reset(self): diff --git a/examples/custom_observation_example_01_SimpleObs.py b/examples/custom_observation_example_01_SimpleObs.py index 705169e95c71137e93f92e8026f82a34d29d2182..8c12886777c0d47952db4aeea8314488d13c19c3 100644 --- a/examples/custom_observation_example_01_SimpleObs.py +++ b/examples/custom_observation_example_01_SimpleObs.py @@ -18,7 +18,6 @@ class SimpleObs(ObservationBuilder): def __init__(self): super().__init__() - self.observation_space = [5] def reset(self): return diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py index e9c2a84eea5e375c024b96a35934e136bb5d40b5..7ddfcd899f747f038471cbe3921e6df76fff37ee 100644 --- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py +++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py @@ -28,7 +28,6 @@ class SingleAgentNavigationObs(ObservationBuilder): def __init__(self): super().__init__() - self.observation_space = [3] def reset(self): pass diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index 2ed47a5f8c18894c8e0a2108b7bc0a8d73b783f9..855d1f5dffbef29da26ca1dead933af846b863bd 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -28,7 +28,6 @@ class ObservePredictions(ObservationBuilder): def __init__(self, predictor): super().__init__() - self.observation_space = [10] self.predictor = predictor def reset(self): diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index 7cb7d9623c79d154e549114e779d39c138cf788d..a52eeed47c5cb1fe75c87d430d93f30f50336fbf 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -25,7 +25,6 @@ class SingleAgentNavigationObs(ObservationBuilder): def __init__(self): super().__init__() - self.observation_space = [3] def reset(self): pass diff --git a/flatland/core/env.py b/flatland/core/env.py index 2dc983c5edd49e4ebb5033877b69f36f293141b2..32b688ca78e35b1e36aac85c0da4a4ee22246d1b 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -11,7 +11,6 @@ class Environment: Derived environments should implement the following attributes: action_space: tuple with the dimensions of the actions to be passed to the step method - observation_space: tuple with the dimensions of the observations returned by reset and step Agents are identified by agent ids (handles). Examples: @@ -46,7 +45,6 @@ class Environment: def __init__(self): self.action_space = () - self.observation_space = () pass def reset(self): diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 2d4df089eed08ee17f3d5f89147735b1b8570a7d..3cc21966162dd28d183493a97cd6072a34abb738 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -18,13 +18,9 @@ from flatland.core.env import Environment class ObservationBuilder: """ ObservationBuilder base class. - - Derived objects must implement and `observation_space` attribute as a tuple with the dimensions of the returned - observations. 
""" def __init__(self): - self.observation_space = () self.env = None def set_env(self, env: Environment): diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index de9ee2a45ebdeabec5202be4f12593a82b4e20e4..c23d4345a03c761ad4c4ac1d936db817f8acc529 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -1,8 +1,8 @@ """ Collection of environment-specific ObservationBuilder. """ -import pprint -from typing import Optional, List, Dict, T, Tuple +import collections +from typing import Optional, List, Dict, Tuple import numpy as np @@ -15,6 +15,22 @@ from flatland.utils.ordered_set import OrderedSet class TreeObsForRailEnv(ObservationBuilder): + + Node = collections.namedtuple('Node', 'dist_own_target_encountered ' + 'dist_other_target_encountered ' + 'dist_other_agent_encountered ' + 'dist_potential_conflict ' + 'dist_unusable_switch ' + 'dist_to_next_branch ' + 'dist_min_to_target ' + 'num_agents_same_direction ' + 'num_agents_opposite_direction ' + 'num_agents_malfunctioning ' + 'speed_min_fractional ' + 'childs') + + tree_explorted_actions_char = ['L', 'F', 'R', 'B'] + """ TreeObsForRailEnv object. @@ -29,24 +45,15 @@ class TreeObsForRailEnv(ObservationBuilder): super().__init__() self.max_depth = max_depth self.observation_dim = 11 - # Compute the size of the returned observation vector - size = 0 - pow4 = 1 - for i in range(self.max_depth + 1): - size += pow4 - pow4 *= 4 - self.observation_space = [size * self.observation_dim] self.location_has_agent = {} self.location_has_agent_direction = {} self.predictor = predictor self.location_has_target = None - self.tree_explored_actions = [1, 2, 3, 0] - self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] def reset(self): self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents} - def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, List[int]]: + def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]: """ Called whenever an observation has to be computed for the `env` environment, for each agent with handle in the `handles` list. 
@@ -75,7 +82,7 @@ class TreeObsForRailEnv(ObservationBuilder): observations[h] = self.get(h) return observations - def get(self, handle: int = 0) -> List[int]: + def get(self, handle: int = 0) -> Node: """ Computes the current observation for agent `handle` in env @@ -165,11 +172,18 @@ class TreeObsForRailEnv(ObservationBuilder): possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) num_transitions = np.count_nonzero(possible_transitions) - # Root node - current position # Here information about the agent itself is stored distance_map = self.env.distance_map.get() - observation = [0, 0, 0, 0, 0, 0, distance_map[(handle, *agent.position, agent.direction)], 0, 0, - agent.malfunction_data['malfunction'], agent.speed_data['speed']] + + root_node_observation = TreeObsForRailEnv.Node(dist_own_target_encountered=0, dist_other_target_encountered=0, + dist_other_agent_encountered=0, dist_potential_conflict=0, + dist_unusable_switch=0, dist_to_next_branch=0, + dist_min_to_target=distance_map[(handle, *agent.position, + agent.direction)], + num_agents_same_direction=0, num_agents_opposite_direction=0, + num_agents_malfunctioning=agent.malfunction_data['malfunction'], + speed_min_fractional=agent.speed_data['speed'], + childs={}) visited = OrderedSet() @@ -181,28 +195,22 @@ class TreeObsForRailEnv(ObservationBuilder): if num_transitions == 1: orientation = np.argmax(possible_transitions) - for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: + for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]): + if possible_transitions[branch_direction]: new_cell = get_new_position(agent.position, branch_direction) + branch_observation, branch_visited = \ self._explore_branch(handle, new_cell, branch_direction, 1, 1) - observation = observation + branch_observation + root_node_observation.childs[self.tree_explorted_actions_char[i]] = branch_observation + visited |= branch_visited else: # add cells filled with infinity if no transition is possible - observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth) + root_node_observation.childs[self.tree_explorted_actions_char[i]] = -np.inf self.env.dev_obs_dict[handle] = visited - return observation - - def _num_cells_to_fill_in(self, remaining_depth): - """Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim.""" - num_observations = 0 - pow4 = 1 - for i in range(remaining_depth): - num_observations += pow4 - pow4 *= 4 - return num_observations * self.observation_dim + return root_node_observation def _explore_branch(self, handle, position, direction, tot_dist, depth): """ @@ -378,53 +386,35 @@ class TreeObsForRailEnv(ObservationBuilder): # Modify here to append new / different features for each visited cell! 
if last_is_target: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - tot_dist, - 0, - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - min_fractional_speed - ] - + dist_to_next_branch = tot_dist + dist_min_to_target = 0 elif last_is_terminal: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - np.inf, - self.env.distance_map.get()[handle, position[0], position[1], direction], - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - min_fractional_speed - ] - + dist_to_next_branch = np.inf + dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction] else: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - tot_dist, - self.env.distance_map.get()[handle, position[0], position[1], direction], - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - min_fractional_speed - ] + dist_to_next_branch = tot_dist + dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction] + + node = TreeObsForRailEnv.Node(dist_own_target_encountered=own_target_encountered, + dist_other_target_encountered=other_target_encountered, + dist_other_agent_encountered=other_agent_encountered, + dist_potential_conflict=potential_conflict, + dist_unusable_switch=unusable_switch, + dist_to_next_branch=dist_to_next_branch, + dist_min_to_target=dist_min_to_target, + num_agents_same_direction=other_agent_same_direction, + num_agents_opposite_direction=other_agent_opposite_direction, + num_agents_malfunctioning=malfunctioning_agent, + speed_min_fractional=min_fractional_speed, + childs={}) + # ############################# # ############################# # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation # Get the possible transitions possible_transitions = self.env.rail.get_transitions(*position, direction) - for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]: + for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]): if last_is_dead_end and self.env.rail.get_transition((*position, direction), (branch_direction + 2) % 4): # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes @@ -435,7 +425,7 @@ class TreeObsForRailEnv(ObservationBuilder): (branch_direction + 2) % 4, tot_dist + 1, depth + 1) - observation = observation + branch_observation + node.childs[self.tree_explorted_actions_char[i]] = branch_observation if len(branch_visited) != 0: visited |= branch_visited elif last_is_switch and possible_transitions[branch_direction]: @@ -445,51 +435,45 @@ class TreeObsForRailEnv(ObservationBuilder): branch_direction, tot_dist + 1, depth + 1) - observation = observation + branch_observation + node.childs[self.tree_explorted_actions_char[i]] = branch_observation if len(branch_visited) != 0: visited |= branch_visited else: # no exploring possible, add just cells with infinity - observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth) + node.childs[self.tree_explorted_actions_char[i]] = -np.inf - return observation, visited + if depth == self.max_depth: + node.childs.clear() 
+ return node, visited - def util_print_obs_subtree(self, tree): + def util_print_obs_subtree(self, tree: Node): """ - Utility function to pretty-print tree observations returned by this object. + Utility function to print tree observations returned by this object. """ - pp = pprint.PrettyPrinter(indent=4) - pp.pprint(self.unfold_observation_tree(tree)) + self.print_node_features(tree, "root", "") + for direction in self.tree_explorted_actions_char: + self.print_subtree(tree.childs[direction], direction, "\t") + + @staticmethod + def print_node_features(node: Node, label, indent): + print(indent, "Direction ", label, ": ", node.dist_own_target_encountered, ", ", + node.dist_other_target_encountered, ", ", node.dist_other_agent_encountered, ", ", + node.dist_potential_conflict, ", ", node.dist_unusable_switch, ", ", node.dist_to_next_branch, ", ", + node.dist_min_to_target, ", ", node.num_agents_same_direction, ", ", node.num_agents_opposite_direction, + ", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional) + + def print_subtree(self, node, label, indent): + if node == -np.inf or not node: + print(indent, "Direction ", label, ": -np.inf") + return - def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True): - """ - Utility function to pretty-print tree observations returned by this object. - """ - if len(tree) < self.observation_dim: + self.print_node_features(node, label, indent) + + if not node.childs: return - depth = 0 - tmp = len(tree) / self.observation_dim - 1 - pow4 = 4 - while tmp > 0: - tmp -= pow4 - depth += 1 - pow4 *= 4 - - unfolded = {} - unfolded[''] = tree[0:self.observation_dim] - child_size = (len(tree) - self.observation_dim) // 4 - for child in range(4): - child_tree = tree[(self.observation_dim + child * child_size): - (self.observation_dim + (child + 1) * child_size)] - observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1) - if observation_tree is not None: - if actions_for_display: - label = self.tree_explorted_actions_char[child] - else: - label = self.tree_explored_actions[child] - unfolded[label] = observation_tree - return unfolded + for direction in self.tree_explorted_actions_char: + self.print_subtree(node.childs[direction], direction, indent + "\t") def set_env(self, env: Environment): super().set_env(env) @@ -508,23 +492,21 @@ class GlobalObsForRailEnv(ObservationBuilder): - transition map array with dimensions (env.height, env.width, 16),\ assuming 16 bits encoding of transitions. + - A 3D array (map_height, map_width, 4) with + - first channel containing the agent's position and direction + - second channel containing the other agents' positions and directions + - third channel containing agent/other agent malfunctions + - fourth channel containing agent/other agent fractional speeds + - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\ target and the positions of the other agents targets.
- - - A 3D array (map_height, map_width, 4) wtih - - first channel containing the agents position and direction - - second channel containing the other agents positions and diretions - - third channel containing agent malfunctions - - fourth channel containing agent fractional speeds """ def __init__(self): - self.observation_space = () super(GlobalObsForRailEnv, self).__init__() def set_env(self, env: Environment): super().set_env(env) - self.observation_space = [4, self.env.height, self.env.width] def reset(self): self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) @@ -535,22 +517,21 @@ class GlobalObsForRailEnv(ObservationBuilder): self.rail_obs[i, j] = np.array(bitlist) def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray): + obs_targets = np.zeros((self.env.height, self.env.width, 2)) - obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) - agents = self.env.agents - agent = agents[handle] + obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) - 1 - agent_pos = agents[handle].position - obs_agents_state[agent_pos][0] = agents[handle].direction + agent = self.env.agents[handle] + obs_agents_state[agent.position][0] = agent.direction obs_targets[agent.target][0] = 1 - for i in range(len(agents)): - if i != handle: # TODO: handle used as index...? - agent2 = agents[i] - obs_agents_state[agent2.position][1] = agent2.direction - obs_targets[agent2.target][1] = 1 - obs_agents_state[agents[i].position][2] = agents[i].malfunction_data['malfunction'] - obs_agents_state[agents[i].position][3] = agents[i].speed_data['speed'] + for i in range(len(self.env.agents)): + other_agent = self.env.agents[i] + if i != handle: + obs_agents_state[other_agent.position][1] = other_agent.direction + obs_targets[other_agent.target][1] = 1 + obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction'] + obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed'] return self.rail_obs, obs_agents_state, obs_targets diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 6fefc3f4d76b27744356f4c378be678303cf234c..98fa4d2e4b557edced94c6eb3692f4949abe9221 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -189,7 +189,6 @@ class RailEnv(Environment): self.distance_map = DistanceMap(self.agents, self.height, self.width) self.action_space = [1] - self.observation_space = self.obs_builder.observation_space # updated on resets? # Stochastic train malfunctioning parameters if stochastic_data is not None: @@ -300,7 +299,6 @@ class RailEnv(Environment): # Reset the state of the observation builder with the new environment self.obs_builder.reset() - self.observation_space = self.obs_builder.observation_space # <-- change on reset? 
self.distance_map.reset(self.agents, self.rail) # Return the new observation vectors for each agent diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index d2663916a17a70597d10e489da7aead4f8932dc4..0d6d309765690b1f95c681d7d109a13071d7f86b 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -41,7 +41,9 @@ def test_global_obs(): # If this assertion is wrong, it means that the observation returned # places the agent on an empty cell - assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0) + obs_agents_state = global_obs[0][1] + obs_agents_state = obs_agents_state + 1 + assert (np.sum(rail_map * obs_agents_state[:, :, :4].sum(2)) > 0) def _step_along_shortest_path(env, obs_builder, rail): diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 7ee0fd4aadf72b12b591259a71af8b408145418f..f4ab68bc45a82b8f196fcfbebb26fd68a36c37a4 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -281,9 +281,10 @@ def test_shortest_path_predictor_conflicts(rendering=False): # get the trees to test obs_builder: TreeObsForRailEnv = env.obs_builder pp = pprint.PrettyPrinter(indent=4) - tree_0 = obs_builder.unfold_observation_tree(observations[0]) - tree_1 = obs_builder.unfold_observation_tree(observations[1]) - pp.pprint(tree_0) + tree_0 = observations[0] + tree_1 = observations[1] + env.obs_builder.util_print_obs_subtree(tree_0) + env.obs_builder.util_print_obs_subtree(tree_1) # check the expectations expected_conflicts_0 = [('F', 'R')] @@ -292,11 +293,18 @@ def test_shortest_path_predictor_conflicts(rendering=False): _check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ") -def _check_expected_conflicts(expected_conflicts, obs_builder, tree_0, prompt=''): - assert (tree_0[''][8] > 0) == (() in expected_conflicts), "{}[]".format(prompt) +def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''): + assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt) for a_1 in obs_builder.tree_explorted_actions_char: - conflict = tree_0[a_1][''][8] - assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1) + if tree.childs[a_1] == -np.inf: + assert False == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1) + continue + else: + conflict = tree.childs[a_1].num_agents_opposite_direction + assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1) for a_2 in obs_builder.tree_explorted_actions_char: - conflict = tree_0[a_1][a_2][''][8] - assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2) + if tree.childs[a_1].childs[a_2] == -np.inf: + assert False == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2) + else: + conflict = tree.childs[a_1].childs[a_2].num_agents_opposite_direction + assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2) diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 86c206097648078b33531fb1bad3b4f091b25ba4..cc61150325ee86710ea3ee3820d8386c7b926da6 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -22,7 +22,6 @@ class SingleAgentNavigationObs(ObservationBuilder): def __init__(self): super().__init__() - self.observation_space = [3] def 
reset(self): pass diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py new file mode 100644 index 0000000000000000000000000000000000000000..7213560f9e9873ea4488b96d30223bab8128b37b --- /dev/null +++ b/tests/test_global_observation.py @@ -0,0 +1,64 @@ +import numpy as np + +from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import sparse_rail_generator +from flatland.envs.schedule_generators import sparse_schedule_generator + + +def test_get_global_observation(): + np.random.seed(1) + number_of_agents = 20 + + stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents + 'malfunction_rate': 30, # Rate of malfunction occurrence + 'min_duration': 3, # Minimal duration of malfunction + 'max_duration': 20 # Max duration of malfunction + } + + speed_ration_map = {1.: 0.25, # Fast passenger train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train + 1. / 4.: 0.25} # Slow freight train + + env = RailEnv(width=50, + height=50, + rail_generator=sparse_rail_generator(num_cities=25, + # Number of cities in map (where train stations are) + num_intersections=10, + # Number of intersections (no start / target) + num_trainstations=50, # Number of possible start/targets on map + min_node_dist=3, # Minimal distance of nodes + node_radius=4, # Proximity of stations to city center + num_neighb=4, + # Number of connections to other cities/intersections + seed=15, # Random seed + grid_mode=True, + enhance_intersection=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=number_of_agents, stochastic_data=stochastic_data, # Malfunction data generator + obs_builder_object=GlobalObsForRailEnv()) + + obs, all_rewards, done, _ = env.step({0: 0}) + + for i in range(len(env.agents)): + obs_agents_state = obs[i][1] + obs_targets = obs[i][2] + + nr_agents = np.count_nonzero(obs_targets[:, :, 0]) + nr_agents_other = np.count_nonzero(obs_targets[:, :, 1]) + assert nr_agents == 1 + assert nr_agents_other == (number_of_agents - 1) + + # since the array is initialized with -1 add one in order to use np.count_nonzero + obs_agents_state += 1 + obs_agents_state_0 = np.count_nonzero(obs_agents_state[:, :, 0]) + obs_agents_state_1 = np.count_nonzero(obs_agents_state[:, :, 1]) + obs_agents_state_2 = np.count_nonzero(obs_agents_state[:, :, 2]) + obs_agents_state_3 = np.count_nonzero(obs_agents_state[:, :, 3]) + assert obs_agents_state_0 == 1 + assert obs_agents_state_1 == (number_of_agents - 1) + assert obs_agents_state_2 == number_of_agents + assert obs_agents_state_3 == number_of_agents
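A note on consuming the new tree observation format (illustration only, not part of the diff above): after this change, TreeObsForRailEnv.get() returns a Node namedtuple whose childs dict is keyed by tree_explorted_actions_char ('L', 'F', 'R', 'B'), with -np.inf marking branches that could not be explored, instead of the flat vector previously sized by the removed observation_space. The sketch below shows one possible way for agents that still expect the old fixed-size vector to flatten such a tree; flatten_tree_obs and its obs_dim = 11 constant are hypothetical helpers written to mirror the removed observation_dim-based layout, not part of Flatland.

.. code-block:: python

    import numpy as np

    from flatland.envs.observations import TreeObsForRailEnv


    def flatten_tree_obs(node, max_depth, depth=0):
        """Depth-first flattening of a TreeObsForRailEnv.Node tree (hypothetical helper)."""
        obs_dim = 11  # scalar features per node, mirroring TreeObsForRailEnv.observation_dim

        if isinstance(node, float) and node == -np.inf:
            # Unexplored branch: pad the whole missing subtree with -inf,
            # as the removed fixed-size code did.
            missing_nodes = sum(4 ** i for i in range(max_depth - depth + 1))
            return [-np.inf] * (missing_nodes * obs_dim)

        data = [node.dist_own_target_encountered,
                node.dist_other_target_encountered,
                node.dist_other_agent_encountered,
                node.dist_potential_conflict,
                node.dist_unusable_switch,
                node.dist_to_next_branch,
                node.dist_min_to_target,
                node.num_agents_same_direction,
                node.num_agents_opposite_direction,
                node.num_agents_malfunctioning,
                node.speed_min_fractional]

        if depth == max_depth:
            # Leaf level: the builder clears node.childs here, nothing left to descend into.
            return data

        for direction in TreeObsForRailEnv.tree_explorted_actions_char:  # 'L', 'F', 'R', 'B'
            data += flatten_tree_obs(node.childs.get(direction, -np.inf), max_depth, depth + 1)

        return data

    # Example usage, assuming an env built with obs_builder_object=TreeObsForRailEnv(max_depth=2, ...):
    # flat = np.array(flatten_tree_obs(obs[handle], max_depth=2))

For max_depth=2, for example, this yields a vector of length (1 + 4 + 16) * 11 = 231, the same size the removed observation_space computation reported.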