diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 4d6da5b5608acd5284e681f376a2d53cf1bd535b..34a5bb539e5d6eb86dd1782d70cc6b406988fee2 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -262,7 +262,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         # until no transitions are possible along the current direction (i.e., dead-ends)
         # We treat dead-ends as nodes, instead of going back, to avoid loops
         exploring = True
-        # TODO: last_isSwitch = False
+        last_isSwitch = False
         # TODO: last_isTerminal = False  # dead-end
         # TODO: last_isTarget = False
         while exploring:
@@ -306,7 +306,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
             elif num_transitions > 0:
                 # Switch detected
-                # TODO: last_isSwitch = True
+                last_isSwitch = True
                 break
 
             elif num_transitions == 0:
@@ -331,7 +331,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]:
-            if self.env.rail.get_transition((position[0], position[1], direction), branch_direction):
+            if last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction), branch_direction):
                 new_cell = self._new_position(position, branch_direction)
 
                 branch_observation = self._explore_branch(handle, new_cell, branch_direction, depth+1)