diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index fb8e107dead2c0bb765c01b223e2317d120dc0ee..d0afc71b3d09d42cf374ef61deaba6008d7fc420 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -701,7 +701,7 @@ class LocalObsForRailEnv(ObservationBuilder): with dimensions (view_height,2*view_width+1, 16), assuming 16 bits encoding of transitions. - - Two 3D arrays (view_height,2*view_width+1, 2) containing respectively, + - Two 2D arrays (view_height,2*view_width+1, 2) containing respectively, if they are in the agent's vision range, its target position, the positions of the other targets. - A 3D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions @@ -744,16 +744,21 @@ class LocalObsForRailEnv(ObservationBuilder): # Collect the rail information in the local field of view local_rail_obs = self.field_of_view(agent_rel_pos, agent.direction, state=self.rail_obs) - # Locate observed agents and their coresponding targets - obs_map_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 2)) - obs_other_agents_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 4)) - # Collect visible cells as set to be plotted - visited = self.field_of_view(agent.position, agent.direction) + visited, rel_coords = self.field_of_view(agent.position, agent.direction) # Add the visible cells to the observed cells - self.env.dev_obs_dict[handle] = visited + self.env.dev_obs_dict[handle] = set(visited) + # Locate observed agents and their coresponding targets + obs_map_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 2)) + obs_other_agents_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 4)) + + for i in range(len(agents)): + temp_agent = agents[i] + if temp_agent.target in visited: + location = np.where(visited == temp_agent.target) + print("I see my target", location, handle) direction = self._get_one_hot_for_agent_direction(agent) return local_rail_obs, obs_map_state, obs_other_agents_state, direction @@ -783,30 +788,35 @@ class LocalObsForRailEnv(ObservationBuilder): origin = (position[0] - self.center, position[1] + self.view_width) else: origin = (position[0] + self.view_width, position[1] + self.center) - visible = set() + visible = list() + rel_coords = list() for h in range(self.view_height): for w in range(2 * self.view_width + 1): if direction == 0: if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: - visible.add((origin[0] - h, origin[1] + w)) + visible.append((origin[0] - h, origin[1] + w)) + rel_coords.append((h, w)) if data_collection: temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :] elif direction == 1: if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width: - visible.add((origin[0] + w, origin[1] + h)) + visible.append((origin[0] + w, origin[1] + h)) + rel_coords.append((h, w)) if data_collection: temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :] elif direction == 2: if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: - visible.add((origin[0] + h, origin[1] - w)) + visible.append((origin[0] + h, origin[1] - w)) + rel_coords.append((h, w)) if data_collection: temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :] else: if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: - visible.add((origin[0] - w, origin[1] - h)) + visible.append((origin[0] - w, origin[1] - h)) + rel_coords.append((h, w)) if data_collection: temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :] if data_collection: return temp_visible_data else: - return visible + return visible, rel_coords