From 245c9cc28044bd2ee3f6140e70ad7b25544bb182 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Thu, 25 Jul 2019 15:29:07 -0400
Subject: [PATCH] At switches, conflicts are now only detected if the agent
 has the option to choose a head-on colliding path

---
 flatland/core/env_observation_builder.py |  1 +
 flatland/envs/observations.py            | 16 ++++++++++++----
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 060785f5..4acdf16f 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -74,6 +74,7 @@ class ObservationBuilder:
         direction[agent.direction] = 1
         return direction
 
+
 class DummyObservationBuilder(ObservationBuilder):
     """
     DummyObservationBuilder class which returns dummy observations
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 498d98ee..8b0c94f1 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -270,6 +270,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
 
         visited = set()
+
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         # If only one transition is possible, the tree is oriented with this transition as the forward branch.
@@ -289,6 +290,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 # add cells filled with infinity if no transition is possible
                 observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth)
         self.env.dev_obs_dict[handle] = visited
+
         return observation
 
     def _num_cells_to_fill_in(self, remaining_depth):
@@ -362,11 +364,12 @@ class TreeObsForRailEnv(ObservationBuilder):
                     pre_step = max(0, tot_dist - 1)
                     post_step = min(self.max_prediction_depth - 1, tot_dist + 1)
 
-                    # Look for conflicting paths at distance num_step
+                    # Look for conflicting paths at distance tot_dist
                     if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0):
                         conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position)
                         for ca in conflicting_agent[0]:
-                            if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict:
+                            if direction != self.predicted_dir[tot_dist][ca] and cell_transitions[self._reverse_dir(
+                                self.predicted_dir[tot_dist][ca])] == 1 and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
                             if self.env.dones[ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
@@ -375,7 +378,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                     elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
                         conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
                         for ca in conflicting_agent[0]:
-                            if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict:
+                            if direction != self.predicted_dir[pre_step][ca] and cell_transitions[self._reverse_dir(
+                                self.predicted_dir[pre_step][ca])] == 1 and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
                             if self.env.dones[ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
@@ -384,7 +388,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                     elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
                         conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
                         for ca in conflicting_agent[0]:
-                            if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict:
+                            if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir(
+                                self.predicted_dir[post_step][ca])] == 1 and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
                             if self.env.dones[ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
@@ -566,6 +571,9 @@ class TreeObsForRailEnv(ObservationBuilder):
         if self.predictor:
             self.predictor._set_env(self.env)
 
+    def _reverse_dir(self, direction):
+        return int((direction + 2) % 4)  # +2 mod 4 flips the heading; int() makes it usable as an index
+
 
 class GlobalObsForRailEnv(ObservationBuilder):
     """
-- 
GitLab