From 30d7994df3d1ba86a7bae2ad0b5a4bec8d20494b Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Sat, 10 Aug 2019 12:08:59 -0400
Subject: [PATCH] Updated tree observation to take agent speed into account
 when checking for potential conlficts

---
 examples/training_example.py  | 2 +-
 flatland/envs/observations.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/training_example.py b/examples/training_example.py
index 5391dbbe..8f7d18f5 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -18,7 +18,7 @@ env = RailEnv(width=20,
               height=20,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
               obs_builder_object=TreeObservation,
-              number_of_agents=10)
+              number_of_agents=3)
 
 env_renderer = RenderTool(env, gl="PILSVG", )
 
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 4fd6bd65..4158675c 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -376,7 +376,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                 self._reverse_dir(
                                     self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
-                            if self.env.dones[ca] and predicted_time < potential_conflict:
+                            if self.env.dones[ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
 
                     # Look for conflicting paths at distance num_step-1
-- 
GitLab