From c0939ad62d7c7e93ddc615c6a0ffe24153abbad0 Mon Sep 17 00:00:00 2001
From: spiglerg <spiglerg@gmail.com>
Date: Fri, 19 Apr 2019 14:28:11 +0200
Subject: [PATCH] added features to treesearch

---
 examples/temporary_example.py            |  8 ++--
 flatland/core/env_observation_builder.py | 52 +++++++++++++++++++-----
 2 files changed, 46 insertions(+), 14 deletions(-)

diff --git a/examples/temporary_example.py b/examples/temporary_example.py
index 9d6046ec..03b5ebd0 100644
--- a/examples/temporary_example.py
+++ b/examples/temporary_example.py
@@ -26,14 +26,14 @@ specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
 env = RailEnv(width=6,
               height=2,
               rail_generator=rail_from_manual_specifications_generator(specs),
-              number_of_agents=1,
+              number_of_agents=2,
               obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
 handle = env.get_agent_handles()
 
-env.agents_position = [[1, 4]]
-env.agents_target = [[1, 1]]
-env.agents_direction = [1]
+env.agents_position[0] = [1, 4]
+env.agents_target[0] = [1, 1]
+env.agents_direction[0] = 1
 # TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
 env.obs_builder.reset()
 
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index e2e74b54..24f213ab 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -69,6 +69,11 @@ class TreeObsForRailEnv(ObservationBuilder):
         for i in range(self.env.number_of_agents):
             self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)
 
+        # Update local lookup table for all agents' target locations
+        self.location_has_target = {}
+        for loc in self.env.agents_target:
+            self.location_has_target[(loc[0],loc[1])] = 1
+
     def _distance_map_walker(self, position, target_nr):
         """
         Utility function to compute distance maps from each cell in the rail network (and each possible
@@ -210,13 +215,14 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         #1:
 
-        #2:
+        #2: 1 if a target of another agent is detected between the previous node and the current one.
 
-        #3:
+        #3: 1 if another agent is detected between the previous node and the current one.
 
         #4:
 
-        #5: minimum distance from node to the agent's target
+        #5: minimum distance from node to the agent's target (when landing on the node following the corresponding
+            branch).
 
         Missing/padding nodes are filled in with -inf (truncated).
         Missing values in present node are filled in with +inf (truncated).
@@ -226,6 +232,11 @@ class TreeObsForRailEnv(ObservationBuilder):
         In case the target node is reached, the values are [0, 0, 0, 0, 0].
         """
 
+        # Update local lookup table for all agents' positions
+        self.location_has_agent = {}
+        for loc in self.env.agents_position:
+            self.location_has_agent[(loc[0], loc[1])] = 1
+
         position = self.env.agents_position[handle]
         orientation = self.env.agents_direction[handle]
 
@@ -264,25 +275,31 @@ class TreeObsForRailEnv(ObservationBuilder):
         exploring = True
         last_isSwitch = False
         last_isDeadEnd = False
-        # TODO: last_isTerminal = False  # dead-end
-        # TODO: last_isTarget = False
+        # TODO: last_isTerminal = False  # wrong cell encountered
+        last_isTarget = False
+
+        other_agent_encountered = False
+        other_target_encountered = False
         while exploring:
             # #############################
             # #############################
             # Modify here to compute any useful data required to build the end node's features. This code is called
             # for each cell visited between the previous branching node and the next switch / target / dead-end.
 
-            # TODO: update the current variables according to the current cell in the path
-            # (store info about other agents and targets)
+            if position in self.location_has_agent:
+                other_agent_encountered = True
+
+            if position in self.location_has_target:
+                other_target_encountered = True
+
 
-            # TODO: [[[for efficiency, [make dict for hashed-lookup of coords] -- do it in the reset function!]]]
 
             # #############################
             # #############################
 
             # If the target node is encountered, pick that as node. Also, no further branching is possible.
             if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
-                # TODO: last_isTarget = True
+                last_isTarget = True
                 break
 
             cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
@@ -323,9 +340,24 @@ class TreeObsForRailEnv(ObservationBuilder):
         # #############################
         # Modify here to append new / different features for each visited cell!
 
-        observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], direction]]
+        if last_isTarget:
+            observation = [0,
+                           1 if other_target_encountered else 0,
+                           1 if other_agent_encountered else 0,
+                           0,
+                           0]
+
+        else:
+            observation = [0,
+                           1 if other_target_encountered else 0,
+                           1 if other_agent_encountered else 0,
+                           0,
+                           self.distance_map[handle, position[0], position[1], direction]]
+
+
         # TODO:
 
+
         # #############################
         # #############################
 
-- 
GitLab