From 33c7a0a180a6a31cd981c4062f117a75e3c9f18c Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume.mollard2@gmail.com>
Date: Wed, 22 May 2019 14:22:55 +0200
Subject: [PATCH] local observation improved with direction

---
 flatland/envs/observations.py | 38 +++++++++++++++++++++--------------
 1 file changed, 23 insertions(+), 15 deletions(-)

diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 09739f87..70000131 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -526,7 +526,6 @@ class GlobalObsForRailEnv(ObservationBuilder):
         return self.rail_obs, obs_map_state, obs_other_agents_state, direction
 
 
-
 class LocalObsForRailEnv(ObservationBuilder):
     """
     Gives a global observation of the entire rail environment.
@@ -536,8 +535,11 @@ class LocalObsForRailEnv(ObservationBuilder):
           with dimensions (2*view_radius + 1, 2*view_radius + 1, 16),
           assuming 16 bits encoding of transitions.
 
-        - Three 2D arrays containing respectively, if they are in the agent's vision range,
-          its target position, the positions of the other agents and of their target.
+        - Two 2D arrays marking respectively, when they are in the agent's vision range,
+          the agent's own target position and the target positions of the other agents.
+
+        - A 3D array (2*view_radius + 1, 2*view_radius + 1, 4) containing the one-hot encoding of the
+          directions of the other agents at their positions, if they are in the agent's vision range.
 
         - A 4 elements array with one hot encoding of the direction.
     """
@@ -555,10 +557,13 @@ class LocalObsForRailEnv(ObservationBuilder):
 
         self.rail_obs = np.zeros((self.env.height + 2*self.view_radius,
                                   self.env.width + 2*self.view_radius, 16))
-        for i in range(self.rail_obs.shape[0]):
-            for j in range(self.rail_obs.shape[1]):
-                self.rail_obs[i+self.view_radius, j+self.view_radius] = np.array(
-                    list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int)
+        for i in range(self.env.height):
+            for j in range(self.env.width):
+                bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
+                bitlist = [0] * (16 - len(bitlist)) + bitlist
+                self.rail_obs[i + self.view_radius, j + self.view_radius] = np.array(bitlist)
+                # self.rail_obs[i+self.view_radius, j+self.view_radius] = np.array(
+                #     list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int)
 
     def get(self, handle):
         agents = self.env.agents
@@ -569,20 +574,22 @@ class LocalObsForRailEnv(ObservationBuilder):
         # top_offset = max(0, agent.position[0] - 1 - self.view_radius)
         # bottom_offset = min(0, agent.position[0] + 1 + self.view_radius)
 
-        local_rail_obs = self.rail_obs[agent.position: agent.position+2*self.view_radius +1,
-                         agent.position:agent.position+2*self.view_radius +1]
+        local_rail_obs = self.rail_obs[agent.position[0]: agent.position[0]+2*self.view_radius +1,
+                         agent.position[1]:agent.position[1]+2*self.view_radius +1]
+
+        obs_map_state = np.zeros((2*self.view_radius +1, 2*self.view_radius + 1, 2))
 
-        obs = np.zeros((3, 2*self.view_radius +1, 2*self.view_radius + 1))
+        obs_other_agents_state = np.zeros((2*self.view_radius +1, 2*self.view_radius +1, 4))
 
         def relative_pos(pos):
             return [agent.position[0] - pos[0], agent.position[1] - pos[1]]
 
         def is_in(rel_pos):
-            return abs(rel_pos) <= self.view_radius
+            return (abs(rel_pos[0]) <= self.view_radius) and (abs(rel_pos[1]) <= self.view_radius)
 
         target_rel_pos = relative_pos(agent.target)
         if is_in(target_rel_pos):
-            obs[0][self.view_radius + 1 + np.array(target_rel_pos)] += 1
+            obs_map_state[self.view_radius + np.array(target_rel_pos)][0] += 1
 
         for i in range(len(agents)):
             if i != handle:  # TODO: handle used as index...?
@@ -590,14 +597,15 @@ class LocalObsForRailEnv(ObservationBuilder):
 
                 agent_2_rel_pos = relative_pos(agent2.position)
                 if is_in(agent_2_rel_pos):
-                    obs[1][self.view_radius + 1 + np.array(agent_2_rel_pos)] += 1
+                    obs_other_agents_state[self.view_radius + agent_2_rel_pos[0],
+                                           self.view_radius + agent_2_rel_pos[1]][agent2.direction] += 1
 
                 target_rel_pos_2 = relative_pos(agent2.position)
                 if is_in(target_rel_pos_2):
-                    obs[2][self.view_radius + 1 + np.array(target_rel_pos_2)] += 1
+                    obs_map_state[self.view_radius + np.array(target_rel_pos_2)][1] += 1
 
         direction = np.zeros(4)
         direction[agent.direction] = 1
 
-        return local_rail_obs, obs, direction
+        return local_rail_obs, obs_map_state, obs_other_agents_state, direction
 
-- 
GitLab