From 3d99e8f3dae8c8cc7b9b80c5a7cff7893d406cb2 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Thu, 6 Jun 2019 16:50:01 +0200
Subject: [PATCH] #56 bugfix wrong length of observation vector; #47 added two
 flags whether there is another agent on the same position in the same or
 opposite direction

---
 flatland/envs/observations.py | 69 +++++++++++++++++++++++++----------
 1 file changed, 49 insertions(+), 20 deletions(-)

diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 18b08f7..4d32faa 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -26,7 +26,10 @@ class TreeObsForRailEnv(ObservationBuilder):
         for i in range(self.max_depth + 1):
             size += pow4
             pow4 *= 4
-        self.observation_space = [size * 6]
+        self.observation_dim = 7
+        self.observation_space = [size * self.observation_dim]
+        self.location_has_agent = {}
+        self.location_has_agent_direction = {}
 
     def reset(self):
         agents = self.env.agents
@@ -181,8 +184,15 @@ class TreeObsForRailEnv(ObservationBuilder):
         #5: minimum distance from node to the agent's target (when landing to the node following the corresponding
             branch.
 
-        #6: agent direction
+        #6: agent in the same direction
+            1 = agent present same direction
+                (possible future use: number of other agents in the same direction in this branch)
+            0 = no agent present same direction
 
+        #7: agent in the opposite drection
+            1 = agent present other direction than myself (so conflict)
+                (possible future use: number of other agents in other direction in this branch, i.e. number of conflicts)
+            0 = no agent present other direction than myself
 
 
         Missing/padding nodes are filled in with -inf (truncated).
@@ -195,13 +205,15 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         # Update local lookup table for all agents' positions
         self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
+        self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents}
         if handle > len(self.env.agents):
             print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
         agent = self.env.agents[handle]  # TODO: handle being treated as index
         possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
         num_transitions = np.count_nonzero(possible_transitions)
         # Root node - current position
-        observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], agent.direction]
+        observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
+
         root_observation = observation[:]
         visited = set()
         # Start from the current orientation, and see which transitions are available;
@@ -212,7 +224,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         if num_transitions == 1:
             orientation == np.argmax(possible_transitions)
 
-        # for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
         for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
             if possible_transitions[branch_direction]:
                 new_cell = self._new_position(agent.position, branch_direction)
@@ -227,7 +238,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 for i in range(self.max_depth):
                     num_cells_to_fill_in += pow4
                     pow4 *= 4
-                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
+                observation = observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in
         self.env.dev_obs_dict[handle] = visited
         return observation
 
@@ -275,7 +286,6 @@ class TreeObsForRailEnv(ObservationBuilder):
             visited.add((position[0], position[1], direction))
 
             # If the target node is encountered, pick that as node. Also, no further branching is possible.
-            # if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
             if np.array_equal(position, self.env.agents[handle].target):
                 last_isTarget = True
                 break
@@ -297,6 +307,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if not last_isDeadEnd:
                     # Keep walking through the tree along `direction'
                     exploring = True
+                    # convert one-hot encoding to 0,1,2,3
                     direction = np.argmax(cell_transitions)
                     position = self._new_position(position, direction)
                     num_steps += 1
@@ -321,36 +332,53 @@ class TreeObsForRailEnv(ObservationBuilder):
         # #############################
         # Modify here to append new / different features for each visited cell!
         """
+        other_agent_same_direction = \
+            1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0
+        other_agent_opposite_direction = \
+            1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0
+
         if last_isTarget:
             observation = [0,
-                           1 if other_target_encountered else 0,
-                           1 if other_agent_encountered else 0,
+                           other_target_encountered,
+                           other_agent_encountered,
                            root_observation[3] + num_steps,
                            0,
-                           direction]
+                           other_agent_same_direction,
+                           other_agent_opposite_direction
+                           ]
 
         elif last_isTerminal:
             observation = [0,
-                           1 if other_target_encountered else 0,
-                           1 if other_agent_encountered else 0,
+                           other_target_encountered,
+                           other_agent_encountered,
                            np.inf,
                            np.inf,
-                           direction]
+                           other_agent_same_direction,
+                           other_agent_opposite_direction
+                           ]
         else:
             observation = [0,
-                           1 if other_target_encountered else 0,
-                           1 if other_agent_encountered else 0,
+                           other_target_encountered,
+                           other_agent_encountered,
                            root_observation[3] + num_steps,
                            self.distance_map[handle, position[0], position[1], direction],
-                           direction]
+                           other_agent_same_direction,
+                           other_agent_opposite_direction
+                           ]
         """
+        other_agent_same_direction = \
+            1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0
+        other_agent_opposite_direction = \
+            1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0
+
         if last_isTarget:
             observation = [0,
                            other_target_encountered,
                            other_agent_encountered,
                            root_observation[3] + num_steps,
                            0,
-                           direction
+                           other_agent_same_direction,
+                           other_agent_opposite_direction
                            ]
 
         elif last_isTerminal:
@@ -359,7 +387,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                            other_agent_encountered,
                            np.inf,
                            np.inf,
-                           direction
+                           other_agent_same_direction,
+                           other_agent_opposite_direction
                            ]
         else:
             observation = [0,
@@ -367,7 +396,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                            other_agent_encountered,
                            root_observation[3] + num_steps,
                            self.distance_map[handle, position[0], position[1], direction],
-                           direction
+                           other_agent_same_direction,
+                           other_agent_opposite_direction
                            ]
         # #############################
         # #############################
@@ -407,8 +437,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 for i in range(self.max_depth - depth):
                     num_cells_to_fill_in += pow4
                     pow4 *= 4
-                observation = \
-                    observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
+                observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in
 
         return observation, visited
 
-- 
GitLab