From 69513a408c99b239387556feaff931365b21a83d Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sun, 1 Sep 2019 09:03:34 -0400 Subject: [PATCH] fixed observations bugs in global observation --- flatland/envs/observations.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index d7f5ce84..a8883824 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -653,10 +653,8 @@ class GlobalObsForRailEnv(ObservationBuilder): agents = self.env.agents agent = agents[handle] - direction = np.zeros(4) - direction[agent.direction] = 1 agent_pos = agents[handle].position - obs_agents_state[agent_pos][0] = direction + obs_agents_state[agent_pos][0] = agents[handle].direction obs_targets[agent.target][0] = 1 for i in range(len(agents)): @@ -664,8 +662,8 @@ class GlobalObsForRailEnv(ObservationBuilder): agent2 = agents[i] obs_agents_state[agent2.position][1] = agent2.direction obs_targets[agent2.target][1] = 1 - obs_agents_state[agent2.position][2] = agent2.malfunction_data['malfunction'] - obs_agents_state[agent2.position][3] = agent2.speed_data['speed'] + obs_agents_state[agents[i].position][2] = agents[i].malfunction_data['malfunction'] + obs_agents_state[agents[i].position][3] = agents[i].speed_data['speed'] return self.rail_obs, obs_agents_state, obs_targets -- GitLab