diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index fd71bb989e37980b5d28294eb7a65550b6935435..7271f6823222ca8e94257140418e03de456a7049 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -127,53 +127,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         """
         neighbors = []
 
-        for direction in range(4):
-            new_cell = self._new_position(position, (direction+2) % 4)
-
-            if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
-                # Check if the two cells are connected by a valid transition
-                transitionValid = False
-                for orientation in range(4):
-                    moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation))
-                    if moves[direction]:
-                        transitionValid = True
-                        break
-
-                if not transitionValid:
-                    continue
-
-                # Check if a transition in direction node[2] is possible if an agent
-                # lands in the current cell with orientation `direction'; this only
-                # applies to cells that are not dead-ends!
-                directionMatch = True
-                if enforce_target_direction >= 0:
-                    directionMatch = self.env.rail.get_transition(
-                        (new_cell[0], new_cell[1], direction), enforce_target_direction)
-
-                # If transition is found to invalid, check if perhaps it
-                # is a dead-end, in which case the direction of movement is rotated
-                # 180 degrees (moving forward turns the agents and makes it step in the previous cell)
-                if not directionMatch:
-                    # If cell is a dead-end, append previous node with reversed
-                    # orientation!
-                    nbits = 0
-                    tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
-                    while tmp > 0:
-                        nbits += (tmp & 1)
-                        tmp = tmp >> 1
-                    if nbits == 1:
-                        # Dead-end!
-                        # Check if transition is possible in new_cell
-                        # with orientation (direction+2)%4 in direction `direction'
-                        directionMatch = directionMatch or self.env.rail.get_transition(
-                            (new_cell[0], new_cell[1], (direction+2) % 4), direction)
-
-                if transitionValid and directionMatch:
-                    new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], direction],
-                                       current_distance+1)
-                    neighbors.append((new_cell[0], new_cell[1], direction, new_distance))
-                    self.distance_map[target_nr, new_cell[0], new_cell[1], direction] = new_distance
-
         possible_directions = [0, 1, 2, 3]
         if enforce_target_direction >= 0:
             # The agent must land into the current cell with orientation `enforce_target_direction'.
@@ -263,7 +216,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         #3: 1 if another agent is detected between the previous node and the current one.
 
-        #4:
+        #4: number of cells travelled from the agent to the current branch node (np.inf for a terminal node)
 
         #5: minimum distance from node to the agent's target (when landing to the node following the corresponding
             branch.
@@ -286,6 +239,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         # Root node - current position
         observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
+        root_observation = observation[:]
 
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
@@ -293,7 +247,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction):
                 new_cell = self._new_position(position, branch_direction)
 
-                branch_observation = self._explore_branch(handle, new_cell, branch_direction, 1)
+                branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
                 observation = observation + branch_observation
             else:
                 num_cells_to_fill_in = 0
@@ -305,7 +259,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         return observation
 
-    def _explore_branch(self, handle, position, direction, depth):
+    def _explore_branch(self, handle, position, direction, root_observation, depth):
         """
         Utility function to compute tree-based observations.
         """
@@ -323,10 +277,11 @@ class TreeObsForRailEnv(ObservationBuilder):
                                  # to land here
         last_isTarget = False
 
-        visited = set([position[0], position[1], direction])
+        visited = set()
 
         other_agent_encountered = False
         other_target_encountered = False
+        num_steps = 1
         while exploring:
             # #############################
             # #############################
@@ -345,6 +300,8 @@ class TreeObsForRailEnv(ObservationBuilder):
             if (position[0], position[1], direction) in visited:
                 last_isTerminal = True
                 break
+            visited.add((position[0], position[1], direction))
+
 
             # If the target node is encountered, pick that as node. Also, no further branching is possible.
             if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
@@ -377,6 +334,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                         if cell_transitions[i]:
                             position = self._new_position(position, i)
                             direction = i
+                            num_steps += 1
                             break
 
             elif num_transitions > 0:
@@ -386,11 +344,10 @@ class TreeObsForRailEnv(ObservationBuilder):
 
             elif num_transitions == 0:
                 # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
+                print("WRONG CELL TYPE detected in tree-search (0 transitions possible)")
                 last_isTerminal = True
                 break
 
-            visited.add((position[0], position[1], direction))
-
         # `position' is either a terminal node or a switch
 
         observation = []
@@ -403,25 +360,27 @@ class TreeObsForRailEnv(ObservationBuilder):
             observation = [0,
                            1 if other_target_encountered else 0,
                            1 if other_agent_encountered else 0,
-                           0,
+                           root_observation[3]+num_steps,
                            0]
 
         elif last_isTerminal:
             observation = [0,
                            1 if other_target_encountered else 0,
                            1 if other_agent_encountered else 0,
-                           0,
+                           np.inf,
                            np.inf]
         else:
             observation = [0,
                            1 if other_target_encountered else 0,
                            1 if other_agent_encountered else 0,
-                           0,
+                           root_observation[3]+num_steps,
                            self.distance_map[handle, position[0], position[1], direction]]
 
         # #############################
         # #############################
 
+        new_root_observation = observation[:]
+
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]:
@@ -431,14 +390,14 @@ class TreeObsForRailEnv(ObservationBuilder):
                 # it back
                 new_cell = self._new_position(position, (branch_direction+2) % 4)
 
-                branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2) % 4, depth+1)
+                branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2) % 4, new_root_observation, depth+1)
                 observation = observation + branch_observation
 
             elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
                                                                 branch_direction):
                 new_cell = self._new_position(position, branch_direction)
 
-                branch_observation = self._explore_branch(handle, new_cell, branch_direction, depth+1)
+                branch_observation = self._explore_branch(handle, new_cell, branch_direction, new_root_observation, depth+1)
                 observation = observation + branch_observation
 
             else:
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 37a97ff80beff3a2d2abc40f297d456021b071f3..cf206033d7aab65571b18de0c95d729f5d09c65c 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -486,6 +486,8 @@ class RailEnv(Environment):
         for handle in self.agents_handles:
             self.dones[handle] = False
 
+        # Use a TreeObsForRailEnv to compute distance maps to each agent's target, so that the
+        # initial orientation sampled for each agent allows a valid solution.
         re_generate = True
         while re_generate:
             valid_positions = []