diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 0fdec364ecc3d8415b27e2101c339d6f49e89022..d8d532c5dc9406a75371df817af550dd931a25b5 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -482,7 +482,7 @@ class GlobalObsForRailEnv(ObservationBuilder): the position of its target, the positions of the other agents and of their target. - - A 4 elements array with one of encoding of the direction. + - A 4 elements array with one hot encoding of the direction. """ def __init__(self): @@ -518,3 +518,79 @@ class GlobalObsForRailEnv(ObservationBuilder): direction[agent.direction] = 1 return self.rail_obs, obs, direction + + +class LocalObsForRailEnv(ObservationBuilder): + """ + Gives a global observation of the entire rail environment. + The observation is composed of the following elements: + + - transition map array of the local environment around the given agent, + with dimensions (2*view_radius + 1, 2*view_radius + 1, 16), + assuming 16 bits encoding of transitions. + + - Three 2D arrays containing respectively, if they are in the agent's vision range, + its target position, the positions of the other agents and of their target. + + - A 4 elements array with one hot encoding of the direction. + """ + + def __init__(self, view_radius): + """ + :param view_radius: + """ + super(LocalObsForRailEnv, self).__init__() + self.view_radius = view_radius + + def reset(self): + # We build the transition map with a view_radius empty cells expansion on each side. + # This helps to collect the local transition map view when the agent is close to a border. + + self.rail_obs = np.zeros((self.env.height + 2*self.view_radius, + self.env.width + 2*self.view_radius, 16)) + for i in range(self.rail_obs.shape[0]): + for j in range(self.rail_obs.shape[1]): + self.rail_obs[i+self.view_radius, j+self.view_radius] = np.array( + list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int) + + def get(self, handle): + agents = self.env.agents + agent = agents[handle] + + # left_offset = max(0, agent.position[1] - 1 - self.view_radius) + # right_offset = min(self.env.width, agent.position[1] + 1 + self.view_radius) + # top_offset = max(0, agent.position[0] - 1 - self.view_radius) + # bottom_offset = min(0, agent.position[0] + 1 + self.view_radius) + + local_rail_obs = self.rail_obs[agent.position: agent.position+2*self.view_radius +1, + agent.position:agent.position+2*self.view_radius +1] + + obs = np.zeros((3, 2*self.view_radius +1, 2*self.view_radius + 1)) + + def relative_pos(pos): + return [agent.position[0] - pos[0], agent.position[1] - pos[1]] + + def is_in(rel_pos): + return abs(rel_pos) <= self.view_radius + + target_rel_pos = relative_pos(agent.target) + if is_in(target_rel_pos): + obs[0][self.view_radius + 1 + np.array(target_rel_pos)] += 1 + + for i in range(len(agents)): + if i != handle: # TODO: handle used as index...? + agent2 = agents[i] + + agent_2_rel_pos = relative_pos(agent2.position) + if is_in(agent_2_rel_pos): + obs[1][self.view_radius + 1 + np.array(agent_2_rel_pos)] += 1 + + target_rel_pos_2 = relative_pos(agent2.position) + if is_in(target_rel_pos_2): + obs[2][self.view_radius + 1 + np.array(target_rel_pos_2)] += 1 + + direction = np.zeros(4) + direction[agent.direction] = 1 + + return local_rail_obs, obs, direction +