From 97a906fb61175233f3b75e87466bb0ae00bf4fca Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume.mollard2@gmail.com>
Date: Fri, 17 May 2019 17:17:34 +0200
Subject: [PATCH] first implementation of local observation, need testing

---
 flatland/envs/observations.py | 78 ++++++++++++++++++++++++++++++++++-
 1 file changed, 77 insertions(+), 1 deletion(-)

diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 0fdec364..d8d532c5 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -482,7 +482,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
           the position of its target, the positions of the other agents and of
           their target.
 
-        - A 4 elements array with one of encoding of the direction.
+        - A 4 elements array with one hot encoding of the direction.
     """
 
     def __init__(self):
@@ -518,3 +518,79 @@ class GlobalObsForRailEnv(ObservationBuilder):
         direction[agent.direction] = 1
 
         return self.rail_obs, obs, direction
+
+
+class LocalObsForRailEnv(ObservationBuilder):
+    """
+    Gives a global observation of the entire rail environment.
+    The observation is composed of the following elements:
+
+        - transition map array of the local environment around the given agent,
+          with dimensions (2*view_radius + 1, 2*view_radius + 1, 16),
+          assuming 16 bits encoding of transitions.
+
+        - Three 2D arrays containing respectively, if they are in the agent's vision range,
+          its target position, the positions of the other agents and of their target.
+
+        - A 4 elements array with one hot encoding of the direction.
+    """
+
+    def __init__(self, view_radius):
+        """
+        :param view_radius:
+        """
+        super(LocalObsForRailEnv, self).__init__()
+        self.view_radius = view_radius
+
+    def reset(self):
+        # We build the transition map with a view_radius empty cells expansion on each side.
+        # This helps to collect the local transition map view when the agent is close to a border.
+
+        self.rail_obs = np.zeros((self.env.height + 2*self.view_radius,
+                                  self.env.width + 2*self.view_radius, 16))
+        for i in range(self.rail_obs.shape[0]):
+            for j in range(self.rail_obs.shape[1]):
+                self.rail_obs[i+self.view_radius, j+self.view_radius] = np.array(
+                    list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int)
+
+    def get(self, handle):
+        agents = self.env.agents
+        agent = agents[handle]
+
+        # left_offset = max(0, agent.position[1] - 1 - self.view_radius)
+        # right_offset = min(self.env.width, agent.position[1] + 1 + self.view_radius)
+        # top_offset = max(0, agent.position[0] - 1 - self.view_radius)
+        # bottom_offset = min(0, agent.position[0] + 1 + self.view_radius)
+
+        local_rail_obs = self.rail_obs[agent.position: agent.position+2*self.view_radius +1,
+                         agent.position:agent.position+2*self.view_radius +1]
+
+        obs = np.zeros((3, 2*self.view_radius +1, 2*self.view_radius + 1))
+
+        def relative_pos(pos):
+            return [agent.position[0] - pos[0], agent.position[1] - pos[1]]
+
+        def is_in(rel_pos):
+            return abs(rel_pos) <= self.view_radius
+
+        target_rel_pos = relative_pos(agent.target)
+        if is_in(target_rel_pos):
+            obs[0][self.view_radius + 1 + np.array(target_rel_pos)] += 1
+
+        for i in range(len(agents)):
+            if i != handle:  # TODO: handle used as index...?
+                agent2 = agents[i]
+
+                agent_2_rel_pos = relative_pos(agent2.position)
+                if is_in(agent_2_rel_pos):
+                    obs[1][self.view_radius + 1 + np.array(agent_2_rel_pos)] += 1
+
+                target_rel_pos_2 = relative_pos(agent2.position)
+                if is_in(target_rel_pos_2):
+                    obs[2][self.view_radius + 1 + np.array(target_rel_pos_2)] += 1
+
+        direction = np.zeros(4)
+        direction[agent.direction] = 1
+
+        return local_rail_obs, obs, direction
+
-- 
GitLab