From 69513a408c99b239387556feaff931365b21a83d Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 1 Sep 2019 09:03:34 -0400
Subject: [PATCH] fixed observations bugs in global observation

---
 flatland/envs/observations.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index d7f5ce84..a8883824 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -653,10 +653,8 @@ class GlobalObsForRailEnv(ObservationBuilder):
         agents = self.env.agents
         agent = agents[handle]
 
-        direction = np.zeros(4)
-        direction[agent.direction] = 1
         agent_pos = agents[handle].position
-        obs_agents_state[agent_pos][0] = direction
+        obs_agents_state[agent_pos][0] = agents[handle].direction
         obs_targets[agent.target][0] = 1
 
         for i in range(len(agents)):
@@ -664,8 +662,8 @@ class GlobalObsForRailEnv(ObservationBuilder):
                 agent2 = agents[i]
                 obs_agents_state[agent2.position][1] = agent2.direction
                 obs_targets[agent2.target][1] = 1
-            obs_agents_state[agent2.position][2] = agent2.malfunction_data['malfunction']
-            obs_agents_state[agent2.position][3] = agent2.speed_data['speed']
+            obs_agents_state[agents[i].position][2] = agents[i].malfunction_data['malfunction']
+            obs_agents_state[agents[i].position][3] = agents[i].speed_data['speed']
 
         return self.rail_obs, obs_agents_state, obs_targets
 
-- 
GitLab