From c158367768cedbf5c3a53dae115aa6c9b72ffd58 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Sun, 8 Nov 2020 21:24:33 +0100
Subject: [PATCH] FastTreeObs (fix) -> 0.8157

---
 utils/fast_tree_obs.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index 0c8af2e..e8df61e 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -161,12 +161,12 @@ class FastTreeObs(ObservationBuilder):
     def _explore(self, handle, new_position, new_direction, depth=0):
         has_opp_agent = 0
         has_same_agent = 0
-        has_switch = 0
+        has_target = 0
         visited = []
 
         # stop exploring (max_depth reached)
         if depth >= self.max_depth:
-            return has_opp_agent, has_same_agent, has_switch, visited
+            return has_opp_agent, has_same_agent, has_target, visited
 
         # max_explore_steps = 100 -> just to ensure that the exploration ends
         cnt = 0
@@ -179,7 +179,7 @@ class FastTreeObs(ObservationBuilder):
                 if self.env.agents[opp_a].direction != new_direction:
                     # opp agent found -> stop exploring. This would be a strong signal.
                     has_opp_agent = 1
-                    return has_opp_agent, has_same_agent, has_switch, visited
+                    return has_opp_agent, has_same_agent, has_target, visited
                 else:
                     # same agent found
                     # the agent can follow the agent, because this agent is still moving ahead and there shouldn't
@@ -199,7 +199,10 @@ class FastTreeObs(ObservationBuilder):
             if agents_near_to_switch:
                 # The exploration was walking on a path where the agent can not decide
                 # Best option would be MOVE_FORWARD -> Skip exploring - just walking
-                return has_opp_agent, has_same_agent, has_switch, visited
+                return has_opp_agent, has_same_agent, has_target, visited
+
+            if self.env.agents[handle].target == new_position:
+                has_target = 1
 
             possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
             if agents_on_switch:
@@ -212,20 +215,20 @@ class FastTreeObs(ObservationBuilder):
                     # --- OPEN RESEARCH QUESTION ---> is this good or shall we use full detailed information as
                     # we did in the TreeObservation (FLATLAND) ?
                     if possible_transitions[dir_loop] == 1:
-                        hoa, hsa, hs, v = self._explore(handle,
+                        hoa, hsa, ht, v = self._explore(handle,
                                                         get_new_position(new_position, dir_loop),
                                                         dir_loop,
                                                         depth + 1)
                         visited.append(v)
                         has_opp_agent = max(has_opp_agent, hoa)
                         has_same_agent = max(has_same_agent, hsa)
-                        has_switch = max(has_switch, hs)
-                return has_opp_agent, has_same_agent, has_switch, visited
+                        has_target = max(has_target, ht)
+                return has_opp_agent, has_same_agent, has_target, visited
             else:
                 new_direction = fast_argmax(possible_transitions)
                 new_position = get_new_position(new_position, new_direction)
 
-        return has_opp_agent, has_same_agent, has_switch, visited
+        return has_opp_agent, has_same_agent, has_target, visited
 
     def get(self, handle):
         # all values are [0,1]
@@ -296,13 +299,13 @@ class FastTreeObs(ObservationBuilder):
                     if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
                         observation[dir_loop] = int(new_cell_dist < current_cell_dist)
 
-                    has_opp_agent, has_same_agent, has_switch, v = self._explore(handle, new_position, branch_direction)
+                    has_opp_agent, has_same_agent, has_target, v = self._explore(handle, new_position, branch_direction)
                     visited.append(v)
 
                     observation[10 + dir_loop] = int(not np.math.isinf(new_cell_dist))
                     observation[14 + dir_loop] = has_opp_agent
                     observation[18 + dir_loop] = has_same_agent
-                    observation[22 + dir_loop] = has_switch
+                    observation[22 + dir_loop] = has_target
 
         agents_on_switch, \
         agents_near_to_switch, \
-- 
GitLab