From c5345c17d84b02d7dd3bbeebc1edf6839f4202e9 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Fri, 19 Jul 2019 10:07:37 -0400
Subject: [PATCH] Updated potential conflict detection. Agents now get warned
 about conflicts when other agents have reached their target and are in the
 way.

---
 flatland/envs/observations.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 3e8583d9..2887610b 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -358,21 +358,26 @@ class TreeObsForRailEnv(ObservationBuilder):
                     if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0):
                         conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position)
                         for ca in conflicting_agent[0]:
-
                             if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
+                            if self.env.dones[ca] and tot_dist < potential_conflict:
+                                potential_conflict = tot_dist
                     # Look for opposing paths at distance num_step-1
                     elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
                         conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
                         for ca in conflicting_agent[0]:
                             if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
+                            if self.env.dones[ca] and tot_dist < potential_conflict:
+                                potential_conflict = tot_dist
                     # Look for opposing paths at distance num_step+1
                     elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
                         conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
                         for ca in conflicting_agent[0]:
                             if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
+                            if self.env.dones[ca] and tot_dist < potential_conflict:
+                                potential_conflict = tot_dist
 
             if position in self.location_has_target and position != agent.target:
                 if tot_dist < other_target_encountered:
-- 
GitLab