diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index edd9cae7d9c632a0564d31607c1d661bdb42641c..a6fbae6d0d271f47e98d08262c7fbc2801b7142d 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -82,8 +82,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         """
         # Returns max distance to target, from the farthest away node, while filling in distance_map
 
-        for ori in range(4):
-            self.distance_map[target_nr, position[0], position[1], ori] = 0
+        self.distance_map[target_nr, position[0], position[1], :] = 0
 
         # Fill in the (up to) 4 neighboring nodes
         # nodes_queue = []  # list of tuples (row, col, direction, distance);
@@ -237,14 +236,18 @@ class TreeObsForRailEnv(ObservationBuilder):
         position = self.env.agents_position[handle]
         orientation = self.env.agents_direction[handle]
         possible_transitions = self.env.rail.get_transitions((position[0], position[1], orientation))
-
+        num_transitions = np.count_nonzero(possible_transitions)
         # Root node - current position
         observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
         root_observation = observation[:]
 
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
-        # TODO: Adjust this to the novel movement dynamics --> Only Forward present when one transition is possible.
+        # If only one transition is possible, the tree is oriented with this transition as the forward branch.
+        # TODO: Test if this works as desired!
+        if num_transitions == 1:
+            orientation == np.argmax(possible_transitions)
+
         for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
             if possible_transitions[branch_direction]:
                 new_cell = self._new_position(position, branch_direction)