From aadff79078f53eb8ed4c686a5976c8d71955f57a Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Tue, 4 Jun 2019 13:47:50 +0200 Subject: [PATCH] 47 agent directions in observation --- flatland/core/env_observation_builder.py | 9 +++- flatland/envs/observations.py | 69 ++++++++++++++---------- 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index f85afee4..b30c2b1f 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -7,13 +7,14 @@ The ObservationBuilder-derived custom classes implement 2 functions, reset() and + Get() is called whenever an observation has to be computed, potentially for each agent independently in case of multi-agent environments. """ +import numpy as np class ObservationBuilder: """ ObservationBuilder base class. - Derived objects must implement and `observation_space' attribute as a tuple with the dimensuions of the returned + Derived objects must implement an `observation_space' attribute as a tuple with the dimensions of the returned observations. """ @@ -45,3 +46,9 @@ class ObservationBuilder: An observation structure, specific to the corresponding environment. """ raise NotImplementedError() + + def _get_one_hot_for_agent_direction(self, agent): + """Returns the agent's direction as a one-hot encoding.""" + direction = np.zeros(4) + direction[agent.direction] = 1 + return direction diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index f5f70aa1..22ebf264 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -1,9 +1,10 @@ """ Collection of environment-specific ObservationBuilder. 
""" -import numpy as np from collections import deque +import numpy as np + from flatland.core.env_observation_builder import ObservationBuilder @@ -22,10 +23,10 @@ class TreeObsForRailEnv(ObservationBuilder): # Compute the size of the returned observation vector size = 0 pow4 = 1 - for i in range(self.max_depth+1): + for i in range(self.max_depth + 1): size += pow4 pow4 *= 4 - self.observation_space = [size * 5] + self.observation_space = [size * 6] def reset(self): agents = self.env.agents @@ -186,6 +187,10 @@ class TreeObsForRailEnv(ObservationBuilder): #5: minimum distance from node to the agent's target (when landing to the node following the corresponding branch. + #6: agent direction + + + Missing/padding nodes are filled in with -inf (truncated). Missing values in present node are filled in with +inf (truncated). @@ -202,13 +207,10 @@ class TreeObsForRailEnv(ObservationBuilder): if handle > len(self.env.agents): print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) agent = self.env.agents[handle] # TODO: handle being treated as index - # position = self.env.agents_position[handle] - # orientation = self.env.agents_direction[handle] possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction)) num_transitions = np.count_nonzero(possible_transitions) # Root node - current position - # observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]] - observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]] + observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], agent.direction] root_observation = observation[:] visited = set() # Start from the current orientation, and see which transitions are available; @@ -337,40 +339,49 @@ class TreeObsForRailEnv(ObservationBuilder): 1 if other_target_encountered else 0, 1 if other_agent_encountered else 0, root_observation[3] + num_steps, - 0] + 0, + direction] elif 
last_isTerminal: observation = [0, 1 if other_target_encountered else 0, 1 if other_agent_encountered else 0, np.inf, - np.inf] + np.inf, + direction] else: observation = [0, 1 if other_target_encountered else 0, 1 if other_agent_encountered else 0, root_observation[3] + num_steps, - self.distance_map[handle, position[0], position[1], direction]] + self.distance_map[handle, position[0], position[1], direction], + direction] """ if last_isTarget: observation = [0, other_target_encountered, other_agent_encountered, root_observation[3] + num_steps, - 0] + 0, + direction + ] elif last_isTerminal: observation = [0, other_target_encountered, other_agent_encountered, np.inf, - np.inf] + np.inf, + direction + ] else: observation = [0, other_target_encountered, other_agent_encountered, root_observation[3] + num_steps, - self.distance_map[handle, position[0], position[1], direction]] + self.distance_map[handle, position[0], position[1], direction], + direction + ] # ############################# # ############################# @@ -409,7 +420,7 @@ class TreeObsForRailEnv(ObservationBuilder): for i in range(self.max_depth - depth): num_cells_to_fill_in += pow4 pow4 *= 4 - observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in + observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in return observation, visited @@ -532,7 +543,9 @@ class GlobalObsForRailEnv(ObservationBuilder): obs_agents_state[agent2.position][4 + agent2.direction] = 1 obs_targets[agent2.target][1] += 1 - return self.rail_obs, obs_agents_state, obs_targets + direction = self._get_one_hot_for_agent_direction(agent) + + return self.rail_obs, obs_agents_state, obs_targets, direction class GlobalObsForRailEnvDirectionDependent(ObservationBuilder): @@ -542,13 +555,15 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder): - transition map array with dimensions (env.height, env.width, 16), assuming 16 bits encoding 
of transitions, flipped in the direction of the agent - (the agent is always heding north on the flipped view). + (the agent is always heading north on the flipped view). - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent target and the positions of the other agents targets, also flipped depending on the agent's direction. - A 3D array (map_height, map_width, 5) containing the one hot encoding of the direction of the other agents at their position coordinates, and the last channel containing the position of the given agent. + + - A 4-element array with the one-hot encoding of the agent's direction. """ def __init__(self): @@ -603,7 +618,9 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder): obs_agents_state[agent2.position][1 + idx[4 + (agent2.direction - direction)]] = 1 obs_targets[agent2.target][1] += 1 - return rail_obs, obs_agents_state, obs_targets + direction = self._get_one_hot_for_agent_direction(agent) + + return rail_obs, obs_agents_state, obs_targets, direction class LocalObsForRailEnv(ObservationBuilder): @@ -635,8 +652,8 @@ class LocalObsForRailEnv(ObservationBuilder): # We build the transition map with a view_radius empty cells expansion on each side. # This helps to collect the local transition map view when the agent is close to a border. 
- self.rail_obs = np.zeros((self.env.height + 2*self.view_radius, - self.env.width + 2*self.view_radius, 16)) + self.rail_obs = np.zeros((self.env.height + 2 * self.view_radius, + self.env.width + 2 * self.view_radius, 16)) for i in range(self.env.height): for j in range(self.env.width): bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]] @@ -654,12 +671,12 @@ class LocalObsForRailEnv(ObservationBuilder): # top_offset = max(0, agent.position[0] - 1 - self.view_radius) # bottom_offset = min(0, agent.position[0] + 1 + self.view_radius) - local_rail_obs = self.rail_obs[agent.position[0]: agent.position[0]+2*self.view_radius + 1, - agent.position[1]:agent.position[1]+2*self.view_radius + 1] + local_rail_obs = self.rail_obs[agent.position[0]: agent.position[0] + 2 * self.view_radius + 1, + agent.position[1]:agent.position[1] + 2 * self.view_radius + 1] - obs_map_state = np.zeros((2*self.view_radius + 1, 2*self.view_radius + 1, 2)) + obs_map_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 2)) - obs_other_agents_state = np.zeros((2*self.view_radius + 1, 2*self.view_radius + 1, 4)) + obs_other_agents_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 4)) def relative_pos(pos): return [agent.position[0] - pos[0], agent.position[1] - pos[1]] @@ -684,15 +701,11 @@ class LocalObsForRailEnv(ObservationBuilder): if is_in(target_rel_pos_2): obs_map_state[self.view_radius + np.array(target_rel_pos_2)][1] += 1 - direction = np.zeros(4) - direction[agent.direction] = 1 + direction = self._get_one_hot_for_agent_direction(agent) return local_rail_obs, obs_map_state, obs_other_agents_state, direction - # class LocalObsForRailEnvImproved(ObservationBuilder): # """ # Returns a local observation around the given agent # """ - - -- GitLab