diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index d7f5ce84eed1f94a6a35d4ca7a1657665ba1b347..a8883824d59f23e2efab69cb53070b0bb18e761d 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -653,10 +653,8 @@ class GlobalObsForRailEnv(ObservationBuilder): agents = self.env.agents agent = agents[handle] - direction = np.zeros(4) - direction[agent.direction] = 1 agent_pos = agents[handle].position - obs_agents_state[agent_pos][0] = direction + obs_agents_state[agent_pos][0] = agents[handle].direction obs_targets[agent.target][0] = 1 for i in range(len(agents)): @@ -664,8 +662,8 @@ class GlobalObsForRailEnv(ObservationBuilder): agent2 = agents[i] obs_agents_state[agent2.position][1] = agent2.direction obs_targets[agent2.target][1] = 1 - obs_agents_state[agent2.position][2] = agent2.malfunction_data['malfunction'] - obs_agents_state[agent2.position][3] = agent2.speed_data['speed'] + obs_agents_state[agents[i].position][2] = agents[i].malfunction_data['malfunction'] + obs_agents_state[agents[i].position][3] = agents[i].speed_data['speed'] return self.rail_obs, obs_agents_state, obs_targets