class LocalObsForRailEnv(ObservationBuilder):
    """
    Gives a local observation of the rail environment around a given agent.
    The observation is composed of the following elements:

        - transition map array of the local environment around the given agent,
          with dimensions (2*view_radius + 1, 2*view_radius + 1, 16),
          assuming 16 bits encoding of transitions.

        - A 3D array (2*view_radius + 1, 2*view_radius + 1, 2) whose two channels mark
          respectively, if they are in the agent's vision range, its own target position
          and the target positions of the other agents.

        - A 3D array (2*view_radius + 1, 2*view_radius + 1, 4) containing the one hot
          encoding of the directions of the other agents at their position coordinates,
          if they are in the agent's vision range.

        - A 4 elements array with one hot encoding of the direction of the agent.

    All three local arrays share the same layout: the observing agent sits at the
    centre cell (view_radius, view_radius) and rows / columns grow in the same
    direction as in the environment grid.
    """

    def __init__(self, view_radius):
        """
        :param view_radius: number of cells visible around the agent in every direction;
                            the local window has side length 2*view_radius + 1.
        """
        super(LocalObsForRailEnv, self).__init__()
        self.view_radius = view_radius

    def reset(self):
        """
        Rebuild the padded transition map from the current environment.

        The map is expanded with view_radius empty cells on each side, so that the
        local window can be extracted with a plain slice even when the agent is
        close to a border.
        """
        radius = self.view_radius
        self.rail_obs = np.zeros((self.env.height + 2 * radius,
                                  self.env.width + 2 * radius, 16))
        for i in range(self.env.height):
            for j in range(self.env.width):
                # 16-bit, zero-padded binary expansion of the cell transitions.
                transitions = int(self.env.rail.get_transitions((i, j)))
                self.rail_obs[i + radius, j + radius] = [int(bit) for bit in format(transitions, "016b")]

    def get(self, handle):
        """
        Build the local observation for one agent.

        :param handle: handle of the agent of interest; used as index into self.env.agents.
        :return: tuple (local_rail_obs, obs_map_state, obs_other_agents_state, direction)
                 as described in the class docstring.
        """
        agents = self.env.agents
        agent = agents[handle]

        radius = self.view_radius
        side = 2 * radius + 1
        row, col = agent.position[0], agent.position[1]

        # rail_obs is padded by `radius` cells, so the window centred on the agent
        # starts at the agent's unpadded coordinates.
        local_rail_obs = self.rail_obs[row:row + side, col:col + side]

        obs_map_state = np.zeros((side, side, 2))
        obs_other_agents_state = np.zeros((side, side, 4))

        def window_index(pos):
            """Return the (row, col) index of `pos` in the local window, or None if out of view."""
            d_row = pos[0] - row
            d_col = pos[1] - col
            if abs(d_row) <= radius and abs(d_col) <= radius:
                return radius + d_row, radius + d_col
            return None

        # Channel 0: target of the agent of interest.
        own_target = window_index(agent.target)
        if own_target is not None:
            # Explicit scalar indices: indexing with an integer array would operate on a copy.
            obs_map_state[own_target[0], own_target[1], 0] += 1

        for other_handle, other in enumerate(agents):
            if other_handle == handle:
                continue

            other_cell = window_index(other.position)
            if other_cell is not None:
                obs_other_agents_state[other_cell[0], other_cell[1], other.direction] += 1

            # Channel 1: targets of the other agents.
            other_target = window_index(other.target)
            if other_target is not None:
                obs_map_state[other_target[0], other_target[1], 1] += 1

        direction = np.zeros(4)
        direction[agent.direction] = 1

        return local_rail_obs, obs_map_state, obs_other_agents_state, direction