diff --git a/examples/training_example.py b/examples/training_example.py
index 5391dbbe4d1709eb7727a6f6fa62612f01439ce0..8f7d18f55ec63f2ed2d57584817d14f8bc722259 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -18,7 +18,7 @@ env = RailEnv(width=20,
               height=20,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
               obs_builder_object=TreeObservation,
-              number_of_agents=10)
+              number_of_agents=3)
 
 env_renderer = RenderTool(env, gl="PILSVG", )
 
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 4fd6bd655095bbf6279280cd0d4c040bfc7fd9f6..4158675cae63394ed768bfc36aaef9cd5f44da7e 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -376,7 +376,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                 self._reverse_dir(
                                     self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
-                            if self.env.dones[ca] and predicted_time < potential_conflict:
+                            if self.env.dones[ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
 
                     # Look for conflicting paths at distance num_step-1