diff --git a/examples/training_example.py b/examples/training_example.py index 67e0c740c2b15e25c65dde8036831f3ef062a831..a05f7c727ac9a56453951de453ea184cab1ea4fc 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -14,11 +14,11 @@ np.random.seed(1) TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) -env = RailEnv(width=20, - height=20, +env = RailEnv(width=50, + height=50, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), obs_builder_object=LocalGridObs, - number_of_agents=2) + number_of_agents=5) env_renderer = RenderTool(env, gl="PILSVG", ) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index d0afc71b3d09d42cf374ef61deaba6008d7fc420..19eb9ab9ee18e8f5f7a389bfe83368cae8187fd0 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -704,7 +704,7 @@ class LocalObsForRailEnv(ObservationBuilder): - Two 2D arrays (view_height,2*view_width+1, 2) containing respectively, if they are in the agent's vision range, its target position, the positions of the other targets. - - A 3D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions + - A 2D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions of the other agents at their position coordinates, if they are in the agent's vision range. - A 4 elements array with one hot encoding of the direction. @@ -724,43 +724,52 @@ class LocalObsForRailEnv(ObservationBuilder): # We build the transition map with a view_radius empty cells expansion on each side. # This helps to collect the local transition map view when the agent is close to a border. self.max_padding = max(self.view_width, self.view_height) - self.rail_obs = np.zeros((self.env.height + 2 * self.max_padding, - self.env.width + 2 * self.max_padding, 16)) + self.rail_obs = np.zeros((self.env.height, + self.env.width, 16)) for i in range(self.env.height): for j in range(self.env.width): bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] bitlist = [0] * (16 - len(bitlist)) + bitlist - self.rail_obs[i + self.view_height, j + self.view_width] = np.array(bitlist) + self.rail_obs[i, j] = np.array(bitlist) def get(self, handle): agents = self.env.agents agent = agents[handle] - agent_rel_pos = [0, 0] # Correct agents position for padding - agent_rel_pos[0] = agent.position[0] + self.max_padding - agent_rel_pos[1] = agent.position[1] + self.max_padding - - # Collect the rail information in the local field of view - local_rail_obs = self.field_of_view(agent_rel_pos, agent.direction, state=self.rail_obs) + # agent_rel_pos[0] = agent.position[0] + self.max_padding + # agent_rel_pos[1] = agent.position[1] + self.max_padding # Collect visible cells as set to be plotted - visited, rel_coords = self.field_of_view(agent.position, agent.direction) + visited, rel_coords = self.field_of_view(agent.position, agent.direction, ) + local_rail_obs = None # Add the visible cells to the observed cells self.env.dev_obs_dict[handle] = set(visited) # Locate observed agents and their coresponding targets - obs_map_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 2)) - obs_other_agents_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 4)) - - for i in range(len(agents)): - temp_agent = agents[i] - if temp_agent.target in visited: - location = np.where(visited == temp_agent.target) - print("I see my target", location, handle) - direction = self._get_one_hot_for_agent_direction(agent) - + local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16)) + obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2)) + obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4)) + _idx = 0 + for pos in visited: + curr_rel_coord = rel_coords[_idx] + local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :] + if pos == agent.target: + obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1 + else: + for tmp_agent in agents: + if pos == tmp_agent.target: + obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1 + if pos != agent.position: + for tmp_agent in agents: + if pos == tmp_agent.position: + obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[ + tmp_agent.direction] + + _idx += 1 + + direction = np.identity(4)[agent.direction] return local_rail_obs, obs_map_state, obs_other_agents_state, direction def get_many(self, handles=None): @@ -796,26 +805,26 @@ class LocalObsForRailEnv(ObservationBuilder): if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: visible.append((origin[0] - h, origin[1] + w)) rel_coords.append((h, w)) - if data_collection: - temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :] + # if data_collection: + # temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :] elif direction == 1: if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width: visible.append((origin[0] + w, origin[1] + h)) rel_coords.append((h, w)) - if data_collection: - temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :] + # if data_collection: + # temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :] elif direction == 2: - if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: + if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width: visible.append((origin[0] + h, origin[1] - w)) rel_coords.append((h, w)) - if data_collection: - temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :] + # if data_collection: + # temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :] else: - if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: + if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width: visible.append((origin[0] - w, origin[1] - h)) rel_coords.append((h, w)) - if data_collection: - temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :] + # if data_collection: + # temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :] if data_collection: return temp_visible_data else: