diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index 7996d03c6923147e9e4d8e1ec168b2fb43a4b410..d9046ec6a937eee04a21e3585608596780fd088c 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -24,17 +24,17 @@ env = RailEnv(width=10,
               height=10,
               rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
               number_of_agents=1)
-
+"""
 env = RailEnv(width=20,
               height=20,
               rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=10, max_dist=99999, seed=0),
               number_of_agents=5)
 
-
+"""
 env = RailEnv(width=20,
               height=20,
               rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
-                  ['../env-data/tests/test1.npy']),
+                  ['../env-data/tests/circle.npy']),
               number_of_agents=1)
 
 
@@ -54,7 +54,7 @@ scores = []
 dones_list = []
 action_prob = [0] * 4
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint14900.pth'))
 
 demo = True
 
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index b5f7ef564694d077fe753bf7857f0a72e3e53a88..f40471f479c3c99ec5800a19073962181be71e0b 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -242,7 +242,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
-            # TODO: check if cell is a curve, then keep branch direction forward instead of left or right
             if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction):
                 new_cell = self._new_position(position, branch_direction)
 
@@ -395,7 +394,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 observation = observation + branch_observation
 
             elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
-                                                                branch_direction):
+                                                                ):
                 new_cell = self._new_position(position, branch_direction)
 
                 branch_observation = self._explore_branch(handle,