From 33c7a0a180a6a31cd981c4062f117a75e3c9f18c Mon Sep 17 00:00:00 2001 From: Guillaume Mollard <guillaume.mollard2@gmail.com> Date: Wed, 22 May 2019 14:22:55 +0200 Subject: [PATCH] local observation improved with direction --- flatland/envs/observations.py | 38 +++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 09739f87..70000131 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -526,7 +526,6 @@ class GlobalObsForRailEnv(ObservationBuilder): return self.rail_obs, obs_map_state, obs_other_agents_state, direction - class LocalObsForRailEnv(ObservationBuilder): """ Gives a global observation of the entire rail environment. @@ -536,8 +535,11 @@ class LocalObsForRailEnv(ObservationBuilder): with dimensions (2*view_radius + 1, 2*view_radius + 1, 16), assuming 16 bits encoding of transitions. - - Three 2D arrays containing respectively, if they are in the agent's vision range, - its target position, the positions of the other agents and of their target. + - Two 2D arrays containing respectively, if they are in the agent's vision range, + its target position, the positions of the other targets. + + - A 3D array (map_height, map_width, 4) containing the one hot encoding of directions + of the other agents at their position coordinates, if they are in the agent's vision range. - A 4 elements array with one hot encoding of the direction. """ @@ -555,10 +557,13 @@ class LocalObsForRailEnv(ObservationBuilder): self.rail_obs = np.zeros((self.env.height + 2*self.view_radius, self.env.width + 2*self.view_radius, 16)) - for i in range(self.rail_obs.shape[0]): - for j in range(self.rail_obs.shape[1]): - self.rail_obs[i+self.view_radius, j+self.view_radius] = np.array( - list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int) + for i in range(self.env.height): + for j in range(self.env.width): + bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]] + bitlist = [0] * (16 - len(bitlist)) + bitlist + self.rail_obs[i + self.view_radius, j + self.view_radius] = np.array(bitlist) + # self.rail_obs[i+self.view_radius, j+self.view_radius] = np.array( + # list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int) def get(self, handle): agents = self.env.agents @@ -569,20 +574,22 @@ class LocalObsForRailEnv(ObservationBuilder): # top_offset = max(0, agent.position[0] - 1 - self.view_radius) # bottom_offset = min(0, agent.position[0] + 1 + self.view_radius) - local_rail_obs = self.rail_obs[agent.position: agent.position+2*self.view_radius +1, - agent.position:agent.position+2*self.view_radius +1] + local_rail_obs = self.rail_obs[agent.position[0]: agent.position[0]+2*self.view_radius +1, + agent.position[1]:agent.position[1]+2*self.view_radius +1] + + obs_map_state = np.zeros((2*self.view_radius +1, 2*self.view_radius + 1, 2)) - obs = np.zeros((3, 2*self.view_radius +1, 2*self.view_radius + 1)) + obs_other_agents_state = np.zeros((2*self.view_radius +1, 2*self.view_radius +1, 4)) def relative_pos(pos): return [agent.position[0] - pos[0], agent.position[1] - pos[1]] def is_in(rel_pos): - return abs(rel_pos) <= self.view_radius + return (abs(rel_pos[0]) <= self.view_radius) and (abs(rel_pos[1]) <= self.view_radius) target_rel_pos = relative_pos(agent.target) if is_in(target_rel_pos): - obs[0][self.view_radius + 1 + np.array(target_rel_pos)] += 1 + obs_map_state[self.view_radius + np.array(target_rel_pos)][0] += 1 for i in range(len(agents)): if i != handle: # TODO: handle used as index...? @@ -590,14 +597,15 @@ class LocalObsForRailEnv(ObservationBuilder): agent_2_rel_pos = relative_pos(agent2.position) if is_in(agent_2_rel_pos): - obs[1][self.view_radius + 1 + np.array(agent_2_rel_pos)] += 1 + obs_other_agents_state[self.view_radius + agent_2_rel_pos[0], + self.view_radius + agent_2_rel_pos[1]][agent2.direction] += 1 target_rel_pos_2 = relative_pos(agent2.position) if is_in(target_rel_pos_2): - obs[2][self.view_radius + 1 + np.array(target_rel_pos_2)] += 1 + obs_map_state[self.view_radius + np.array(target_rel_pos_2)][1] += 1 direction = np.zeros(4) direction[agent.direction] = 1 - return local_rail_obs, obs, direction + return local_rail_obs, obs_map_state, obs_other_agents_state, direction -- GitLab