From 8176e2fe7fc6b492576bd1946a3897a626918ed2 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Wed, 3 Jul 2019 16:14:41 -0400 Subject: [PATCH] updated comment to make tree observation more understandable --- examples/training_example.py | 2 + flatland/envs/observations.py | 79 +++++++++++++++++++++-------------- flatland/utils/rendertools.py | 2 +- 3 files changed, 51 insertions(+), 32 deletions(-) diff --git a/examples/training_example.py b/examples/training_example.py index dd9ded92..cdbd1ade 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -76,6 +76,8 @@ for trials in range(1, n_trials + 1): for a in range(env.get_num_agents()): action = agent.act(obs[a]) action_dict.update({a: action}) + # Uncomment next line to print observation of an agent + # TreeObservation.util_print_obs_subtree((obs[a])) # Environment step which returns the observations for all agents, their corresponding # reward and whether their are done diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 8ed455ca..8222115e 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -28,7 +28,7 @@ class TreeObsForRailEnv(ObservationBuilder): for i in range(self.max_depth + 1): size += pow4 pow4 *= 4 - self.observation_dim = 8 + self.observation_dim = 9 self.observation_space = [size * self.observation_dim] self.location_has_agent = {} self.location_has_agent_direction = {} @@ -223,24 +223,29 @@ class TreeObsForRailEnv(ObservationBuilder): #3: if another agent is detected the distance in number of cells from current agent position is stored. - #4: This feature stores the distance in number of cells to the next branching store (current node) + #4: possible conflict detected + tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the + distance in number of cells from current agent position - #5: minimum distance from node to the agent's target given the direction of the agent if this path is chosen + 0 = No other agent reserve the same cell at similar time + + #5: if an not usable switch (for agent) is detected we store the distance. + + #6: This feature stores the distance in number of cells to the next branching (current node) + + #7: minimum distance from node to the agent's target given the direction of the agent if this path is chosen - #6: agent in the same direction + #8: agent in the same direction n = number of agents present same direction (possible future use: number of other agents in the same direction in this branch) 0 = no agent present same direction - #7: agent in the opposite drection + #9: agent in the opposite drection n = number of agents present other direction than myself (so conflict) (possible future use: number of other agents in other direction in this branch, ie. number of conflicts) 0 = no agent present other direction than myself - #8: possible conflict detected - 1 = Other agent predicts to pass along this cell at the same time as the agent - 0 = No other agent reserve the same cell at similar time Missing/padding nodes are filled in with -inf (truncated). @@ -261,7 +266,7 @@ class TreeObsForRailEnv(ObservationBuilder): num_transitions = np.count_nonzero(possible_transitions) # Root node - current position - observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, 0] + observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0] root_observation = observation[:] visited = set() @@ -294,6 +299,8 @@ class TreeObsForRailEnv(ObservationBuilder): def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth): """ Utility function to compute tree-based observations. + We walk along the branch and collect the information documented in the get() function. + If there is a branching point a new node is created and each possible branch is explored. """ # [Recursive branch opened] if depth >= self.max_depth + 1: @@ -313,9 +320,11 @@ class TreeObsForRailEnv(ObservationBuilder): own_target_encountered = np.inf other_agent_encountered = np.inf other_target_encountered = np.inf + potential_conflict = np.inf + unusable_switch = np.inf other_agent_same_direction = 0 other_agent_opposite_direction = 0 - potential_conflict = 0 + num_steps = 1 while exploring: # ############################# @@ -323,8 +332,8 @@ class TreeObsForRailEnv(ObservationBuilder): # Modify here to compute any useful data required to build the end node's features. This code is called # for each cell visited between the previous branching node and the next switch / target / dead-end. if position in self.location_has_agent: - if num_steps < other_agent_encountered: - other_agent_encountered = num_steps + if tot_dist < other_agent_encountered: + other_agent_encountered = tot_dist if self.location_has_agent_direction[position] == direction: # Cummulate the number of agents on branch with same direction @@ -345,28 +354,28 @@ class TreeObsForRailEnv(ObservationBuilder): if int_position in np.delete(self.predicted_pos[tot_dist], handle): conflicting_agent = np.where(np.delete(self.predicted_pos[tot_dist], handle) == int_position) for ca in conflicting_agent: - if direction != self.predicted_dir[tot_dist][ca[0]]: - potential_conflict = 1 + if direction != self.predicted_dir[tot_dist][ca[0]] and tot_dist < potential_conflict: + potential_conflict = tot_dist # Look for opposing paths at distance num_step-1 elif int_position in np.delete(self.predicted_pos[pre_step], handle): conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) for ca in conflicting_agent: - if direction != self.predicted_dir[pre_step][ca[0]]: - potential_conflict = 1 + if direction != self.predicted_dir[pre_step][ca[0]] and tot_dist < potential_conflict: + potential_conflict = tot_dist # Look for opposing paths at distance num_step+1 elif int_position in np.delete(self.predicted_pos[post_step], handle): conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position) for ca in conflicting_agent: - if direction != self.predicted_dir[post_step][ca[0]]: - potential_conflict = 1 + if direction != self.predicted_dir[post_step][ca[0]] and tot_dist < potential_conflict: + potential_conflict = tot_dist if position in self.location_has_target and position != agent.target: - if num_steps < other_target_encountered: - other_target_encountered = num_steps + if tot_dist < other_target_encountered: + other_target_encountered = tot_dist if position == agent.target: - if num_steps < own_target_encountered: - own_target_encountered = num_steps + if tot_dist < own_target_encountered: + own_target_encountered = tot_dist # ############################# # ############################# @@ -382,8 +391,13 @@ class TreeObsForRailEnv(ObservationBuilder): break cell_transitions = self.env.rail.get_transitions((*position, direction)) + total_transitions = bin(self.env.rail.get_transitions(position)).count("1") num_transitions = np.count_nonzero(cell_transitions) exploring = False + # Detect Switches that can only be used by other agents. + if total_transitions > 2 > num_transitions: + unusable_switch = tot_dist + if num_transitions == 1: # Check if dead-end, or if we can go forward along direction nbits = 0 @@ -462,32 +476,35 @@ class TreeObsForRailEnv(ObservationBuilder): observation = [own_target_encountered, other_target_encountered, other_agent_encountered, - root_observation[3] + num_steps, + potential_conflict, + unusable_switch, + tot_dist, 0, other_agent_same_direction, - other_agent_opposite_direction, - potential_conflict + other_agent_opposite_direction ] elif last_isTerminal: observation = [own_target_encountered, other_target_encountered, other_agent_encountered, + potential_conflict, + unusable_switch, np.inf, - np.inf, + self.distance_map[handle, position[0], position[1], direction], other_agent_same_direction, - other_agent_opposite_direction, - potential_conflict + other_agent_opposite_direction ] else: observation = [own_target_encountered, other_target_encountered, other_agent_encountered, - root_observation[3] + num_steps, + potential_conflict, + unusable_switch, + tot_dist, self.distance_map[handle, position[0], position[1], direction], other_agent_same_direction, other_agent_opposite_direction, - potential_conflict ] # ############################# # ############################# @@ -531,7 +548,7 @@ class TreeObsForRailEnv(ObservationBuilder): return observation, visited - def util_print_obs_subtree(self, tree, num_features_per_node=8, prompt='', current_depth=0): + def util_print_obs_subtree(self, tree, num_features_per_node=9, prompt='', current_depth=0): """ Utility function to pretty-print tree observations returned by this object. """ diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 7932aaca..537d8be8 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -38,7 +38,7 @@ class RenderTool(object): gTheta = np.linspace(0, np.pi / 2, 5) gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1] - def __init__(self, env, gl="PILSVG", jupyter=False, agentRenderVariant=AgentRenderVariant.AGENT_SHOWS_OPTIONS): + def __init__(self, env, gl="PILSVG", jupyter=False, agentRenderVariant=AgentRenderVariant.ONE_STEP_BEHIND): self.env = env self.iFrame = 0 self.time1 = time.time() -- GitLab