diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 18b08f756ff9bc8aa5414d170a4b791b21a4b24a..4d32faa42c17d3f7bfe908b631c02fc29a1ef2f2 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -26,7 +26,10 @@ class TreeObsForRailEnv(ObservationBuilder): for i in range(self.max_depth + 1): size += pow4 pow4 *= 4 - self.observation_space = [size * 6] + self.observation_dim = 7 + self.observation_space = [size * self.observation_dim] + self.location_has_agent = {} + self.location_has_agent_direction = {} def reset(self): agents = self.env.agents @@ -181,8 +184,15 @@ class TreeObsForRailEnv(ObservationBuilder): #5: minimum distance from node to the agent's target (when landing to the node following the corresponding branch. - #6: agent direction + #6: agent in the same direction + 1 = agent present same direction + (possible future use: number of other agents in the same direction in this branch) + 0 = no agent present same direction + #7: agent in the opposite drection + 1 = agent present other direction than myself (so conflict) + (possible future use: number of other agents in other direction in this branch, i.e. number of conflicts) + 0 = no agent present other direction than myself Missing/padding nodes are filled in with -inf (truncated). @@ -195,13 +205,15 @@ class TreeObsForRailEnv(ObservationBuilder): # Update local lookup table for all agents' positions self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents} + self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents} if handle > len(self.env.agents): print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) agent = self.env.agents[handle] # TODO: handle being treated as index possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction)) num_transitions = np.count_nonzero(possible_transitions) # Root node - current position - observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], agent.direction] + observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0] + root_observation = observation[:] visited = set() # Start from the current orientation, and see which transitions are available; @@ -212,7 +224,6 @@ class TreeObsForRailEnv(ObservationBuilder): if num_transitions == 1: orientation == np.argmax(possible_transitions) - # for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]: for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: if possible_transitions[branch_direction]: new_cell = self._new_position(agent.position, branch_direction) @@ -227,7 +238,7 @@ class TreeObsForRailEnv(ObservationBuilder): for i in range(self.max_depth): num_cells_to_fill_in += pow4 pow4 *= 4 - observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in + observation = observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in self.env.dev_obs_dict[handle] = visited return observation @@ -275,7 +286,6 @@ class TreeObsForRailEnv(ObservationBuilder): visited.add((position[0], position[1], direction)) # If the target node is encountered, pick that as node. Also, no further branching is possible. - # if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]: if np.array_equal(position, self.env.agents[handle].target): last_isTarget = True break @@ -297,6 +307,7 @@ class TreeObsForRailEnv(ObservationBuilder): if not last_isDeadEnd: # Keep walking through the tree along `direction' exploring = True + # convert one-hot encoding to 0,1,2,3 direction = np.argmax(cell_transitions) position = self._new_position(position, direction) num_steps += 1 @@ -321,36 +332,53 @@ class TreeObsForRailEnv(ObservationBuilder): # ############################# # Modify here to append new / different features for each visited cell! """ + other_agent_same_direction = \ + 1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0 + other_agent_opposite_direction = \ + 1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0 + if last_isTarget: observation = [0, - 1 if other_target_encountered else 0, - 1 if other_agent_encountered else 0, + other_target_encountered, + other_agent_encountered, root_observation[3] + num_steps, 0, - direction] + other_agent_same_direction, + other_agent_opposite_direction + ] elif last_isTerminal: observation = [0, - 1 if other_target_encountered else 0, - 1 if other_agent_encountered else 0, + other_target_encountered, + other_agent_encountered, np.inf, np.inf, - direction] + other_agent_same_direction, + other_agent_opposite_direction + ] else: observation = [0, - 1 if other_target_encountered else 0, - 1 if other_agent_encountered else 0, + other_target_encountered, + other_agent_encountered, root_observation[3] + num_steps, self.distance_map[handle, position[0], position[1], direction], - direction] + other_agent_same_direction, + other_agent_opposite_direction + ] """ + other_agent_same_direction = \ + 1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0 + other_agent_opposite_direction = \ + 1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0 + if last_isTarget: observation = [0, other_target_encountered, other_agent_encountered, root_observation[3] + num_steps, 0, - direction + other_agent_same_direction, + other_agent_opposite_direction ] elif last_isTerminal: @@ -359,7 +387,8 @@ class TreeObsForRailEnv(ObservationBuilder): other_agent_encountered, np.inf, np.inf, - direction + other_agent_same_direction, + other_agent_opposite_direction ] else: observation = [0, @@ -367,7 +396,8 @@ class TreeObsForRailEnv(ObservationBuilder): other_agent_encountered, root_observation[3] + num_steps, self.distance_map[handle, position[0], position[1], direction], - direction + other_agent_same_direction, + other_agent_opposite_direction ] # ############################# # ############################# @@ -407,8 +437,7 @@ class TreeObsForRailEnv(ObservationBuilder): for i in range(self.max_depth - depth): num_cells_to_fill_in += pow4 pow4 *= 4 - observation = \ - observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in + observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in return observation, visited