From 51c747b6b4efa36bf2e326dbb3c3751aaa23ba7b Mon Sep 17 00:00:00 2001 From: Guillaume Mollard <guillaume.mollard2@gmail.com> Date: Mon, 20 May 2019 15:25:15 +0200 Subject: [PATCH] added direction of other agents in global observation --- flatland/envs/observations.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index fa77b58a..0fd94a4e 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -203,7 +203,7 @@ class TreeObsForRailEnv(ObservationBuilder): num_transitions = np.count_nonzero(possible_transitions) # Root node - current position # observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]] - observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]] + observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]] root_observation = observation[:] visited = set() # Start from the current orientation, and see which transitions are available; @@ -478,11 +478,15 @@ class GlobalObsForRailEnv(ObservationBuilder): - transition map array with dimensions (env.height, env.width, 16), assuming 16 bits encoding of transitions. - - Four 2D arrays containing respectively the position of the given agent, - the position of its target, the positions of the other agents and of - their target. + - Three 2D arrays (map_height, map_width, 3) containing respectively the position of the given agent, + the position of its target and the positions of the other agents' targets. + + - A 3D array (map_height, map_width, 4) containing the one hot encoding of directions + of the other agents at their position coordinates. + + - A 4 elements array with one hot encoding of the direction of the agent of interest. + - - A 4 elements array with one of encoding of the direction. 
""" def __init__(self): @@ -503,21 +507,22 @@ class GlobalObsForRailEnv(ObservationBuilder): # self.targets[target_pos] += 1 def get(self, handle): - obs = np.zeros((4, self.env.height, self.env.width)) + obs_map_state = np.zeros((self.env.height, self.env.width, 3)) + obs_other_agents_state = np.zeros((self.env.height, self.env.width, 4)) agents = self.env.agents agent = agents[handle] agent_pos = agents[handle].position - obs[0][agent_pos] += 1 - obs[1][agent.target] += 1 + obs_map_state[agent_pos][0] += 1 + obs_map_state[agent.target][1] += 1 for i in range(len(agents)): if i != handle: # TODO: handle used as index...? agent2 = agents[i] - obs[3][agent2.position] += 1 - obs[2][agent2.target] += 1 + obs_other_agents_state[agent2.position][agent2.direction] = 1 + obs_map_state[agent2.target][2] += 1 direction = np.zeros(4) direction[agent.direction] = 1 - return self.rail_obs, obs, direction + return self.rail_obs, obs_map_state, obs_other_agents_state, direction -- GitLab