diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index b097b5df1e65a249d850e55e164ea5e8124767f9..0fdec364ecc3d8415b27e2101c339d6f49e89022 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -205,7 +205,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         # observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
         observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]]
         root_observation = observation[:]
-
+        visited = set()
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         # If only one transition is possible, the tree is oriented with this transition as the forward branch.
@@ -219,9 +219,10 @@ class TreeObsForRailEnv(ObservationBuilder):
             if possible_transitions[branch_direction]:
                 new_cell = self._new_position(agent.position, branch_direction)
 
-                branch_observation, visited = self._explore_branch(handle, new_cell, branch_direction, root_observation,
+                branch_observation, branch_visited = self._explore_branch(handle, new_cell, branch_direction, root_observation,
                                                                    1)
                 observation = observation + branch_observation
+                visited = visited.union(branch_visited)
             else:
                 num_cells_to_fill_in = 0
                 pow4 = 1