From 51c747b6b4efa36bf2e326dbb3c3751aaa23ba7b Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume.mollard2@gmail.com>
Date: Mon, 20 May 2019 15:25:15 +0200
Subject: [PATCH] added direction of other agents in global observation

---
 flatland/envs/observations.py | 27 ++++++++++++++++-----------
 1 file changed, 16 insertions(+), 11 deletions(-)

diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index fa77b58a..0fd94a4e 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -203,7 +203,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         num_transitions = np.count_nonzero(possible_transitions)
         # Root node - current position
         # observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
-        observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]]
+        observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]]
         root_observation = observation[:]
         visited = set()
         # Start from the current orientation, and see which transitions are available;
@@ -478,11 +478,15 @@ class GlobalObsForRailEnv(ObservationBuilder):
         - transition map array with dimensions (env.height, env.width, 16),
           assuming 16 bits encoding of transitions.
 
-        - Four 2D arrays containing respectively the position of the given agent,
-          the position of its target, the positions of the other agents and of
-          their target.
+        - A 3D array (map_height, map_width, 3) whose three channels contain respectively the position of the given agent,
+          the position of its target and the positions of the other agents' targets.
+
+        - A 3D array (map_height, map_width, 4) containing the one hot encoding of directions
+          of the other agents at their position coordinates.
+
+        - A 4-element array with the one-hot encoding of the direction of the agent of interest.
+
 
-        - A 4 elements array with one of encoding of the direction.
     """
 
     def __init__(self):
@@ -503,21 +507,22 @@ class GlobalObsForRailEnv(ObservationBuilder):
         #     self.targets[target_pos] += 1
 
     def get(self, handle):
-        obs = np.zeros((4, self.env.height, self.env.width))
+        obs_map_state = np.zeros((self.env.height, self.env.width, 3))
+        obs_other_agents_state = np.zeros((self.env.height, self.env.width, 4))
         agents = self.env.agents
         agent = agents[handle]
 
         agent_pos = agents[handle].position
-        obs[0][agent_pos] += 1
-        obs[1][agent.target] += 1
+        obs_map_state[agent_pos][0] += 1
+        obs_map_state[agent.target][1] += 1
 
         for i in range(len(agents)):
             if i != handle:  # TODO: handle used as index...?
                 agent2 = agents[i]
-                obs[3][agent2.position] += 1
-                obs[2][agent2.target] += 1
+                obs_other_agents_state[agent2.position][agent2.direction] = 1
+                obs_map_state[agent2.target][2] += 1
 
         direction = np.zeros(4)
         direction[agent.direction] = 1
 
-        return self.rail_obs, obs, direction
+        return self.rail_obs, obs_map_state, obs_other_agents_state,  direction
-- 
GitLab