From 437e6193a1277937b0de74923f829902bb1d7669 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Wed, 2 Oct 2019 17:52:53 -0400
Subject: [PATCH] updated global observation to account for multiple agents

---
 flatland/envs/observations.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index c23d4345..548c3e29 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -492,11 +492,12 @@ class GlobalObsForRailEnv(ObservationBuilder):
         - transition map array with dimensions (env.height, env.width, 16),\
           assuming 16 bits encoding of transitions.
 
-        - A 3D array (map_height, map_width, 4) with
+        - A 3D array (map_height, map_width, 5) with
             - first channel containing the agents position and direction
             - second channel containing the other agents positions and diretion
             - third channel containing agent/other agent malfunctions
             - fourth channel containing agent/other agent fractional speeds
+            ' fifth channel containing number of agents in cell (only larger then one at start position)
 
         - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\
          target and the positions of the other agents targets.
@@ -519,14 +520,16 @@ class GlobalObsForRailEnv(ObservationBuilder):
     def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
 
         obs_targets = np.zeros((self.env.height, self.env.width, 2))
-        obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) - 1
-
+        obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1
+        obs_agents_state[0] -= 1  # Set all values to -1 to avoid confusion with orientation
+        obs_agents_state[1] -= 1  # Set all values to -1 to avoid confusion with orientation
         agent = self.env.agents[handle]
         obs_agents_state[agent.position][0] = agent.direction
         obs_targets[agent.target][0] = 1
 
         for i in range(len(self.env.agents)):
             other_agent = self.env.agents[i]
+            obs_agents_state[other_agent.position][4] += 1
             if i != handle:
                 obs_agents_state[other_agent.position][1] = other_agent.direction
                 obs_targets[other_agent.target][1] = 1
-- 
GitLab