From 46a422749846c074c49a39ae4b2bd51ddae052f8 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Fri, 14 Jun 2019 00:58:04 +0200
Subject: [PATCH] added direction to the conflict detection

---
 examples/training_example.py  |  2 +-
 flatland/envs/observations.py | 26 +++++++++++++++-----------
 flatland/envs/predictions.py  |  2 +-
 3 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/examples/training_example.py b/examples/training_example.py
index 218f133e..ad8396a5 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -16,7 +16,7 @@ env = RailEnv(width=20,
               height=20,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
               obs_builder_object=TreeObservation,
-              number_of_agents=2)
+              number_of_agents=1)
 
 
 # Import your own Agent or use RLlib to train agents on Flatland
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 439d538b..b4925024 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -331,17 +331,21 @@ class TreeObsForRailEnv(ObservationBuilder):
 
             # Register possible conflict
             if self.predictor and num_steps < self.max_prediction_depth:
-                if coordinate_to_position(self.env.width, [position]) in np.delete(self.predicted_pos[num_steps],
-                                                                                   handle):
-                    potential_conflict = 1
-                if coordinate_to_position(self.env.width, [position]) in np.delete(
-                    self.predicted_pos[max(0, num_steps - 1)],
-                    handle):
-                    potential_conflict = 1
-                if coordinate_to_position(self.env.width, [position]) in np.delete(
-                    self.predicted_pos[min(self.max_prediction_depth - 1, num_steps + 1)],
-                    handle):
-                    potential_conflict = 1
+                int_position = coordinate_to_position(self.env.width, [position])
+                pre_step = max(0, num_steps - 1)
+                post_step = min(self.max_prediction_depth - 1, num_steps + 1)
+                if int_position in np.delete(self.predicted_pos[num_steps], handle):
+                    conflicting_agent = np.where(np.delete(self.predicted_pos[num_steps], handle) == int_position)[0][0]
+                    if direction != self.predicted_dir[num_steps][conflicting_agent]:
+                        potential_conflict = 1
+                elif int_position in np.delete(self.predicted_pos[pre_step], handle):
+                    conflicting_agent = np.where(np.delete(self.predicted_pos[pre_step], handle) == int_position)[0][0]
+                    if direction != self.predicted_dir[pre_step][conflicting_agent]:
+                        potential_conflict = 1
+                elif int_position in np.delete(self.predicted_pos[post_step], handle):
+                    conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)[0][0]
+                    if direction != self.predicted_dir[post_step][conflicting_agent]:
+                        potential_conflict = 1
 
             if position in self.location_has_target and position != agent.target:
                 if num_steps < other_target_encountered:
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 366acfeb..f5f559a1 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -135,7 +135,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                             if target_dist < min_dist:
                                 min_dist = target_dist
                                 new_direction = direct
-                                new_position = self._new_position(agent.position, new_direction)
+                    new_position = self._new_position(agent.position, new_direction)
 
                 agent.position = new_position
                 agent.direction = new_direction
-- 
GitLab