From 1a9c8c2856d792e4533a75c70c821ff21d0ea3a7 Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume.mollard2@gmail.com>
Date: Wed, 29 May 2019 10:18:31 +0200
Subject: [PATCH] added global obs direction dependent

---
 flatland/envs/observations.py | 106 +++++++++++++++++++++++++++++-----
 1 file changed, 91 insertions(+), 15 deletions(-)

diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 96d91579..f5f70aa1 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -483,13 +483,12 @@ class GlobalObsForRailEnv(ObservationBuilder):
         - transition map array with dimensions (env.height, env.width, 16),
           assuming 16 bits encoding of transitions.
 
-        - Three 2D arrays (map_height, map_width, 3) containing respectively the position of the given agent,
-          the position of its target and the positions of the other agents targets.
+        - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent's
+          target and the positions of the other agents' targets.
 
-        - A 3D array (map_height, map_width, 4) containing the one hot encoding of directions
+        - A 3D array (map_height, map_width, 8) with the first 4 channels containing the one hot encoding
+          of the direction of the given agent and the last 4 channels containing the one hot directions
           of the other agents at their position coordinates.
-
-        - A 4 elements array with one of encoding of the direction of the agent of interest.
     """
 
     def __init__(self):
@@ -516,30 +515,100 @@ class GlobalObsForRailEnv(ObservationBuilder):
         #     self.targets[target_pos] += 1
 
     def get(self, handle):
-        obs_map_state = np.zeros((self.env.height, self.env.width, 3))
-        obs_other_agents_state = np.zeros((self.env.height, self.env.width, 4))
+        obs_targets = np.zeros((self.env.height, self.env.width, 2))
+        obs_agents_state = np.zeros((self.env.height, self.env.width, 8))
         agents = self.env.agents
         agent = agents[handle]
 
+        direction = np.zeros(4)
+        direction[agent.direction] = 1
         agent_pos = agents[handle].position
-        obs_map_state[agent_pos][0] += 1
-        obs_map_state[agent.target][1] += 1
+        obs_agents_state[agent_pos][:4] = direction
+        obs_targets[agent.target][0] += 1
 
         for i in range(len(agents)):
             if i != handle:  # TODO: handle used as index...?
                 agent2 = agents[i]
-                obs_other_agents_state[agent2.position][agent2.direction] = 1
-                obs_map_state[agent2.target][2] += 1
+                obs_agents_state[agent2.position][4 + agent2.direction] = 1
+                obs_targets[agent2.target][1] += 1
 
-        direction = np.zeros(4)
-        direction[agent.direction] = 1
+        return self.rail_obs, obs_agents_state, obs_targets
+
+
+class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
+    """
+    Gives a global observation of the entire rail environment.
+    The observation is composed of the following elements:
+
+        - transition map array with dimensions (env.height, env.width, 16),
+          assuming 16 bits encoding of transitions, flipped in the direction of the agent
+          (the agent is always heading north on the flipped view).
+
+        - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent's
+          target and the positions of the other agents' targets (NOTE: get() currently leaves these unflipped).
 
-        return self.rail_obs, obs_map_state, obs_other_agents_state, direction
+        - A 3D array (map_height, map_width, 5): channel 0 marks the position of the given agent, channels 1-4
+          one hot encode the direction of each other agent, relative to the given agent's, at its position.
+    """
+
+    def __init__(self):
+        self.observation_space = ()
+        super(GlobalObsForRailEnvDirectionDependent, self).__init__()
+
+    def _set_env(self, env):
+        super()._set_env(env)
+
+        self.observation_space = [4, self.env.height, self.env.width]
+
+    def reset(self):
+        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
+        for i in range(self.rail_obs.shape[0]):
+            for j in range(self.rail_obs.shape[1]):
+                bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
+                bitlist = [0] * (16 - len(bitlist)) + bitlist
+                self.rail_obs[i, j] = np.array(bitlist)
+                # self.rail_obs[i, j] = np.array(
+                #     list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int)
+
+        # self.targets = np.zeros(self.env.height, self.env.width)
+        # for target_pos in self.env.agents_target:
+        #     self.targets[target_pos] += 1
+
+    def get(self, handle):
+        obs_targets = np.zeros((self.env.height, self.env.width, 2))
+        obs_agents_state = np.zeros((self.env.height, self.env.width, 5))
+        agents = self.env.agents
+        agent = agents[handle]
+        direction = agent.direction
+
+        idx = np.tile(np.arange(16), 2)
+
+        rail_obs = self.rail_obs[:, :, idx[direction * 4: direction * 4 + 16]]
+
+        if direction == 1:
+            rail_obs = np.flip(rail_obs, axis=1)
+        elif direction == 2:
+            rail_obs = np.flip(rail_obs)
+        elif direction == 3:
+            rail_obs = np.flip(rail_obs, axis=0)
+
+        agent_pos = agents[handle].position
+        obs_agents_state[agent_pos][0] = 1
+        obs_targets[agent.target][0] += 1
+
+        idx = np.tile(np.arange(4), 2)
+        for i in range(len(agents)):
+            if i != handle:  # TODO: handle used as index...?
+                agent2 = agents[i]
+                obs_agents_state[agent2.position][1 + idx[4 + (agent2.direction - direction)]] = 1
+                obs_targets[agent2.target][1] += 1
+
+        return rail_obs, obs_agents_state, obs_targets
 
 
 class LocalObsForRailEnv(ObservationBuilder):
     """
-    Gives a global observation of the entire rail environment.
+    Gives a local observation of the rail environment around the agent.
     The observation is composed of the following elements:
 
         - transition map array of the local environment around the given agent,
@@ -620,3 +689,10 @@ class LocalObsForRailEnv(ObservationBuilder):
 
         return local_rail_obs, obs_map_state, obs_other_agents_state, direction
 
+
+# class LocalObsForRailEnvImproved(ObservationBuilder):
+#     """
+#     Returns a local observation around the given agent
+#     """
+
+
-- 
GitLab