diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index fd71bb989e37980b5d28294eb7a65550b6935435..86485ec2c068ea410d5d27997f6f037d3aab6c23 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -12,8 +12,6 @@ import numpy as np
 
 from collections import deque
 
-# TODO: add docstrings, pylint, etc...
-
 
 class ObservationBuilder:
     """
@@ -127,53 +125,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         """
         neighbors = []
 
-        for direction in range(4):
-            new_cell = self._new_position(position, (direction+2) % 4)
-
-            if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
-                # Check if the two cells are connected by a valid transition
-                transitionValid = False
-                for orientation in range(4):
-                    moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation))
-                    if moves[direction]:
-                        transitionValid = True
-                        break
-
-                if not transitionValid:
-                    continue
-
-                # Check if a transition in direction node[2] is possible if an agent
-                # lands in the current cell with orientation `direction'; this only
-                # applies to cells that are not dead-ends!
-                directionMatch = True
-                if enforce_target_direction >= 0:
-                    directionMatch = self.env.rail.get_transition(
-                        (new_cell[0], new_cell[1], direction), enforce_target_direction)
-
-                # If transition is found to invalid, check if perhaps it
-                # is a dead-end, in which case the direction of movement is rotated
-                # 180 degrees (moving forward turns the agents and makes it step in the previous cell)
-                if not directionMatch:
-                    # If cell is a dead-end, append previous node with reversed
-                    # orientation!
-                    nbits = 0
-                    tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
-                    while tmp > 0:
-                        nbits += (tmp & 1)
-                        tmp = tmp >> 1
-                    if nbits == 1:
-                        # Dead-end!
-                        # Check if transition is possible in new_cell
-                        # with orientation (direction+2)%4 in direction `direction'
-                        directionMatch = directionMatch or self.env.rail.get_transition(
-                            (new_cell[0], new_cell[1], (direction+2) % 4), direction)
-
-                if transitionValid and directionMatch:
-                    new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], direction],
-                                       current_distance+1)
-                    neighbors.append((new_cell[0], new_cell[1], direction, new_distance))
-                    self.distance_map[target_nr, new_cell[0], new_cell[1], direction] = new_distance
-
         possible_directions = [0, 1, 2, 3]
         if enforce_target_direction >= 0:
             # The agent must land into the current cell with orientation `enforce_target_direction'.
@@ -263,7 +214,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         #3: 1 if another agent is detected between the previous node and the current one.
 
-        #4:
+        #4: distance of agent to the current branch node
 
         #5: minimum distance from node to the agent's target (when landing to the node following the corresponding
             branch.
@@ -286,6 +237,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         # Root node - current position
         observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
+        root_observation = observation[:]
 
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
@@ -293,7 +245,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction):
                 new_cell = self._new_position(position, branch_direction)
 
-                branch_observation = self._explore_branch(handle, new_cell, branch_direction, 1)
+                branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
                 observation = observation + branch_observation
             else:
                 num_cells_to_fill_in = 0
@@ -305,7 +257,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         return observation
 
-    def _explore_branch(self, handle, position, direction, depth):
+    def _explore_branch(self, handle, position, direction, root_observation, depth):
         """
         Utility function to compute tree-based observations.
         """
@@ -319,14 +271,14 @@ class TreeObsForRailEnv(ObservationBuilder):
         exploring = True
         last_isSwitch = False
         last_isDeadEnd = False
-        last_isTerminal = False  # wrong cell encountered OR cycle encountered;  either way, we don't want the agent
-                                 # to land here
+        last_isTerminal = False  # wrong cell OR cycle;  either way, we don't want the agent to land here
         last_isTarget = False
 
-        visited = set([position[0], position[1], direction])
+        visited = set()
 
         other_agent_encountered = False
         other_target_encountered = False
+        num_steps = 1
         while exploring:
             # #############################
             # #############################
@@ -345,6 +297,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             if (position[0], position[1], direction) in visited:
                 last_isTerminal = True
                 break
+            visited.add((position[0], position[1], direction))
 
             # If the target node is encountered, pick that as node. Also, no further branching is possible.
             if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
@@ -377,6 +330,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                         if cell_transitions[i]:
                             position = self._new_position(position, i)
                             direction = i
+                            num_steps += 1
                             break
 
             elif num_transitions > 0:
@@ -386,11 +340,10 @@ class TreeObsForRailEnv(ObservationBuilder):
 
             elif num_transitions == 0:
                 # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
+                print("WRONG CELL TYPE detected in tree-search (0 transitions possible)")
                 last_isTerminal = True
                 break
 
-            visited.add((position[0], position[1], direction))
-
         # `position' is either a terminal node or a switch
 
         observation = []
@@ -403,25 +356,27 @@ class TreeObsForRailEnv(ObservationBuilder):
             observation = [0,
                            1 if other_target_encountered else 0,
                            1 if other_agent_encountered else 0,
-                           0,
+                           root_observation[3]+num_steps,
                            0]
 
         elif last_isTerminal:
             observation = [0,
                            1 if other_target_encountered else 0,
                            1 if other_agent_encountered else 0,
-                           0,
+                           np.inf,
                            np.inf]
         else:
             observation = [0,
                            1 if other_target_encountered else 0,
                            1 if other_agent_encountered else 0,
-                           0,
+                           root_observation[3]+num_steps,
                            self.distance_map[handle, position[0], position[1], direction]]
 
         # #############################
         # #############################
 
+        new_root_observation = observation[:]
+
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]:
@@ -431,14 +386,22 @@ class TreeObsForRailEnv(ObservationBuilder):
                 # it back
                 new_cell = self._new_position(position, (branch_direction+2) % 4)
 
-                branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2) % 4, depth+1)
+                branch_observation = self._explore_branch(handle,
+                                                          new_cell,
+                                                          (branch_direction+2) % 4,
+                                                          new_root_observation,
+                                                          depth+1)
                 observation = observation + branch_observation
 
             elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
                                                                 branch_direction):
                 new_cell = self._new_position(position, branch_direction)
 
-                branch_observation = self._explore_branch(handle, new_cell, branch_direction, depth+1)
+                branch_observation = self._explore_branch(handle,
+                                                          new_cell,
+                                                          branch_direction,
+                                                          new_root_observation,
+                                                          depth+1)
                 observation = observation + branch_observation
 
             else:
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 37a97ff80beff3a2d2abc40f297d456021b071f3..cf206033d7aab65571b18de0c95d729f5d09c65c 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -486,6 +486,8 @@ class RailEnv(Environment):
         for handle in self.agents_handles:
             self.dones[handle] = False
 
+        # Use a TreeObsForRailEnv to compute distance maps to each agent's target, to sample initial
+        # agent's orientations that allow a valid solution.
         re_generate = True
         while re_generate:
             valid_positions = []
diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py
index 44dbfc6d8f14a7293e148f125b47eed66c9ca08d..db264c2975b75f32c2612aa19c0511076460ec6b 100644
--- a/tests/test_env_observation_builder.py
+++ b/tests/test_env_observation_builder.py
@@ -1,10 +1,11 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+import numpy as np
+
 from flatland.core.env_observation_builder import GlobalObsForRailEnv
 from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
 from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator
-from flatland.utils.rendertools import *
 
 """Tests for `flatland` package."""
 
@@ -45,14 +46,14 @@ def test_global_obs():
         double_switch_south_horizontal_straight, 180)
 
     rail_map = np.array(
-        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
-        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
-        [[dead_end_from_east] + [horizontal_straight] * 2 +
-         [double_switch_north_horizontal_straight] +
-        [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
-        [horizontal_straight] * 2 + [dead_end_from_west]] +
-        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
-        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+               [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+               [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
+               [[dead_end_from_east] + [horizontal_straight] * 2 +
+                [double_switch_north_horizontal_straight] +
+                [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
+                [horizontal_straight] * 2 + [dead_end_from_west]] +
+               [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+               [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
 
     rail = GridTransitionMap(width=rail_map.shape[1],
                              height=rail_map.shape[0], transitions=transitions)
@@ -81,17 +82,3 @@ def test_global_obs():
     # If this assertion is wrong, it means that the observation returned
     # places the agent on an empty cell
     assert(np.sum(rail_map * global_obs[0][1][0]) > 0)
-
-
-
-test_global_obs()
-
-
-
-
-
-
-
-
-
-