From ff2ea11da7abccddd42b5dd7920e7875dc3304af Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sat, 15 Jun 2019 13:06:26 +0200
Subject: [PATCH] updated how distance from agent is measured to detect
 conflicts!

---
 flatland/envs/observations.py | 55 +++++++++++++++++++----------------
 1 file changed, 30 insertions(+), 25 deletions(-)

diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 931fabc7..007f6939 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -273,7 +273,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             if possible_transitions[branch_direction]:
                 new_cell = self._new_position(agent.position, branch_direction)
                 branch_observation, branch_visited = \
-                    self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
+                    self._explore_branch(handle, new_cell, branch_direction, root_observation, 0, 1)
                 observation = observation + branch_observation
                 visited = visited.union(branch_visited)
             else:
@@ -286,7 +286,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.env.dev_obs_dict[handle] = visited
         return observation
 
-    def _explore_branch(self, handle, position, direction, root_observation, depth):
+    def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth):
         """
         Utility function to compute tree-based observations.
         """
@@ -332,26 +332,31 @@ class TreeObsForRailEnv(ObservationBuilder):
             # Register possible conflict
             if self.predictor and num_steps < self.max_prediction_depth:
                 int_position = coordinate_to_position(self.env.width, [position])
-                pre_step = max(0, num_steps - 1)
-                post_step = min(self.max_prediction_depth - 1, num_steps + 1)
-                # Look for opposing paths at distance num_step
-                if int_position in np.delete(self.predicted_pos[num_steps], handle):
-                    conflicting_agent = np.where(np.delete(self.predicted_pos[num_steps], handle) == int_position)
-                    for ca in conflicting_agent:
-                        if direction != self.predicted_dir[num_steps][ca[0]]:
-                            potential_conflict = 1
-                # Look for opposing paths at distance num_step-1
-                elif int_position in np.delete(self.predicted_pos[pre_step], handle):
-                    conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
-                    for ca in conflicting_agent:
-                        if direction != self.predicted_dir[pre_step][ca[0]]:
-                            potential_conflict = 1
-                # Look for opposing paths at distance num_step+1
-                elif int_position in np.delete(self.predicted_pos[post_step], handle):
-                    conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)
-                    for ca in conflicting_agent:
-                        if direction != self.predicted_dir[post_step][ca[0]]:
-                            potential_conflict = 1
+                if tot_dist < self.max_prediction_depth:
+                    pre_step = max(0, tot_dist - 1)
+                    post_step = min(self.max_prediction_depth - 1, tot_dist + 1)
+
+                    # Look for opposing paths at distance num_step
+                    if int_position in np.delete(self.predicted_pos[tot_dist], handle):
+                        conflicting_agent = np.where(np.delete(self.predicted_pos[tot_dist], handle) == int_position)
+                        for ca in conflicting_agent:
+                            if direction != self.predicted_dir[tot_dist][ca[0]]:
+                                potential_conflict = 1
+                                # print("Potential Conflict",position,handle,ca[0],tot_dist,depth)
+                    # Look for opposing paths at distance num_step-1
+                    elif int_position in np.delete(self.predicted_pos[pre_step], handle):
+                        conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
+                        for ca in conflicting_agent:
+                            if direction != self.predicted_dir[pre_step][ca[0]]:
+                                potential_conflict = 1
+                                # print("Potential Conflict", position,handle,ca[0],pre_step,depth)
+                    # Look for opposing paths at distance num_step+1
+                    elif int_position in np.delete(self.predicted_pos[post_step], handle):
+                        conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)
+                        for ca in conflicting_agent:
+                            if direction != self.predicted_dir[post_step][ca[0]]:
+                                potential_conflict = 1
+                                # print("Potential Conflict", position,handle,ca[0],post_step,depth)
 
             if position in self.location_has_target and position != agent.target:
                 if num_steps < other_target_encountered:
@@ -395,7 +400,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                     direction = np.argmax(cell_transitions)
                     position = self._new_position(position, direction)
                     num_steps += 1
-
+                    tot_dist += 1
             elif num_transitions > 0:
                 # Switch detected
                 last_isSwitch = True
@@ -499,7 +504,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 branch_observation, branch_visited = self._explore_branch(handle,
                                                                           new_cell,
                                                                           (branch_direction + 2) % 4,
-                                                                          new_root_observation,
+                                                                          new_root_observation, tot_dist + 1,
                                                                           depth + 1)
                 observation = observation + branch_observation
                 if len(branch_visited) != 0:
@@ -509,7 +514,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 branch_observation, branch_visited = self._explore_branch(handle,
                                                                           new_cell,
                                                                           branch_direction,
-                                                                          new_root_observation,
+                                                                          new_root_observation, tot_dist + 1,
                                                                           depth + 1)
                 observation = observation + branch_observation
                 if len(branch_visited) != 0:
-- 
GitLab