Skip to content
Snippets Groups Projects
Commit 8176e2fe authored by Erik Nygren's avatar Erik Nygren
Browse files

updated comment to make tree observation more understandable

parent 30a149a5
No related branches found
No related tags found
No related merge requests found
...@@ -76,6 +76,8 @@ for trials in range(1, n_trials + 1): ...@@ -76,6 +76,8 @@ for trials in range(1, n_trials + 1):
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
action = agent.act(obs[a]) action = agent.act(obs[a])
action_dict.update({a: action}) action_dict.update({a: action})
# Uncomment next line to print observation of an agent
# TreeObservation.util_print_obs_subtree((obs[a]))
# Environment step which returns the observations for all agents, their corresponding # Environment step which returns the observations for all agents, their corresponding
# reward and whether they are done # reward and whether they are done
......
...@@ -28,7 +28,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -28,7 +28,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth + 1): for i in range(self.max_depth + 1):
size += pow4 size += pow4
pow4 *= 4 pow4 *= 4
self.observation_dim = 8 self.observation_dim = 9
self.observation_space = [size * self.observation_dim] self.observation_space = [size * self.observation_dim]
self.location_has_agent = {} self.location_has_agent = {}
self.location_has_agent_direction = {} self.location_has_agent_direction = {}
...@@ -223,24 +223,29 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -223,24 +223,29 @@ class TreeObsForRailEnv(ObservationBuilder):
#3: if another agent is detected the distance in number of cells from current agent position is stored. #3: if another agent is detected the distance in number of cells from current agent position is stored.
#4: This feature stores the distance in number of cells to the next branching store (current node) #4: possible conflict detected
tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the
distance in number of cells from current agent position
#5: minimum distance from node to the agent's target given the direction of the agent if this path is chosen 0 = No other agent reserve the same cell at similar time
#5: if a switch that is not usable by this agent is detected, we store the distance.
#6: This feature stores the distance in number of cells to the next branching (current node)
#7: minimum distance from node to the agent's target given the direction of the agent if this path is chosen
#6: agent in the same direction #8: agent in the same direction
n = number of agents present same direction n = number of agents present same direction
(possible future use: number of other agents in the same direction in this branch) (possible future use: number of other agents in the same direction in this branch)
0 = no agent present same direction 0 = no agent present same direction
#7: agent in the opposite direction #9: agent in the opposite direction
n = number of agents present other direction than myself (so conflict) n = number of agents present other direction than myself (so conflict)
(possible future use: number of other agents in other direction in this branch, ie. number of conflicts) (possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
0 = no agent present other direction than myself 0 = no agent present other direction than myself
#8: possible conflict detected
1 = Other agent predicts to pass along this cell at the same time as the agent
0 = No other agent reserve the same cell at similar time
Missing/padding nodes are filled in with -inf (truncated). Missing/padding nodes are filled in with -inf (truncated).
...@@ -261,7 +266,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -261,7 +266,7 @@ class TreeObsForRailEnv(ObservationBuilder):
num_transitions = np.count_nonzero(possible_transitions) num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position # Root node - current position
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, 0] observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
root_observation = observation[:] root_observation = observation[:]
visited = set() visited = set()
...@@ -294,6 +299,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -294,6 +299,8 @@ class TreeObsForRailEnv(ObservationBuilder):
def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth): def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth):
""" """
Utility function to compute tree-based observations. Utility function to compute tree-based observations.
We walk along the branch and collect the information documented in the get() function.
If there is a branching point a new node is created and each possible branch is explored.
""" """
# [Recursive branch opened] # [Recursive branch opened]
if depth >= self.max_depth + 1: if depth >= self.max_depth + 1:
...@@ -313,9 +320,11 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -313,9 +320,11 @@ class TreeObsForRailEnv(ObservationBuilder):
own_target_encountered = np.inf own_target_encountered = np.inf
other_agent_encountered = np.inf other_agent_encountered = np.inf
other_target_encountered = np.inf other_target_encountered = np.inf
potential_conflict = np.inf
unusable_switch = np.inf
other_agent_same_direction = 0 other_agent_same_direction = 0
other_agent_opposite_direction = 0 other_agent_opposite_direction = 0
potential_conflict = 0
num_steps = 1 num_steps = 1
while exploring: while exploring:
# ############################# # #############################
...@@ -323,8 +332,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -323,8 +332,8 @@ class TreeObsForRailEnv(ObservationBuilder):
# Modify here to compute any useful data required to build the end node's features. This code is called # Modify here to compute any useful data required to build the end node's features. This code is called
# for each cell visited between the previous branching node and the next switch / target / dead-end. # for each cell visited between the previous branching node and the next switch / target / dead-end.
if position in self.location_has_agent: if position in self.location_has_agent:
if num_steps < other_agent_encountered: if tot_dist < other_agent_encountered:
other_agent_encountered = num_steps other_agent_encountered = tot_dist
if self.location_has_agent_direction[position] == direction: if self.location_has_agent_direction[position] == direction:
# Accumulate the number of agents on branch with same direction
...@@ -345,28 +354,28 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -345,28 +354,28 @@ class TreeObsForRailEnv(ObservationBuilder):
if int_position in np.delete(self.predicted_pos[tot_dist], handle): if int_position in np.delete(self.predicted_pos[tot_dist], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[tot_dist], handle) == int_position) conflicting_agent = np.where(np.delete(self.predicted_pos[tot_dist], handle) == int_position)
for ca in conflicting_agent: for ca in conflicting_agent:
if direction != self.predicted_dir[tot_dist][ca[0]]: if direction != self.predicted_dir[tot_dist][ca[0]] and tot_dist < potential_conflict:
potential_conflict = 1 potential_conflict = tot_dist
# Look for opposing paths at distance num_step-1 # Look for opposing paths at distance num_step-1
elif int_position in np.delete(self.predicted_pos[pre_step], handle): elif int_position in np.delete(self.predicted_pos[pre_step], handle):
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent: for ca in conflicting_agent:
if direction != self.predicted_dir[pre_step][ca[0]]: if direction != self.predicted_dir[pre_step][ca[0]] and tot_dist < potential_conflict:
potential_conflict = 1 potential_conflict = tot_dist
# Look for opposing paths at distance num_step+1 # Look for opposing paths at distance num_step+1
elif int_position in np.delete(self.predicted_pos[post_step], handle): elif int_position in np.delete(self.predicted_pos[post_step], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position) conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)
for ca in conflicting_agent: for ca in conflicting_agent:
if direction != self.predicted_dir[post_step][ca[0]]: if direction != self.predicted_dir[post_step][ca[0]] and tot_dist < potential_conflict:
potential_conflict = 1 potential_conflict = tot_dist
if position in self.location_has_target and position != agent.target: if position in self.location_has_target and position != agent.target:
if num_steps < other_target_encountered: if tot_dist < other_target_encountered:
other_target_encountered = num_steps other_target_encountered = tot_dist
if position == agent.target: if position == agent.target:
if num_steps < own_target_encountered: if tot_dist < own_target_encountered:
own_target_encountered = num_steps own_target_encountered = tot_dist
# ############################# # #############################
# ############################# # #############################
...@@ -382,8 +391,13 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -382,8 +391,13 @@ class TreeObsForRailEnv(ObservationBuilder):
break break
cell_transitions = self.env.rail.get_transitions((*position, direction)) cell_transitions = self.env.rail.get_transitions((*position, direction))
total_transitions = bin(self.env.rail.get_transitions(position)).count("1")
num_transitions = np.count_nonzero(cell_transitions) num_transitions = np.count_nonzero(cell_transitions)
exploring = False exploring = False
# Detect Switches that can only be used by other agents.
if total_transitions > 2 > num_transitions:
unusable_switch = tot_dist
if num_transitions == 1: if num_transitions == 1:
# Check if dead-end, or if we can go forward along direction # Check if dead-end, or if we can go forward along direction
nbits = 0 nbits = 0
...@@ -462,32 +476,35 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -462,32 +476,35 @@ class TreeObsForRailEnv(ObservationBuilder):
observation = [own_target_encountered, observation = [own_target_encountered,
other_target_encountered, other_target_encountered,
other_agent_encountered, other_agent_encountered,
root_observation[3] + num_steps, potential_conflict,
unusable_switch,
tot_dist,
0, 0,
other_agent_same_direction, other_agent_same_direction,
other_agent_opposite_direction, other_agent_opposite_direction
potential_conflict
] ]
elif last_isTerminal: elif last_isTerminal:
observation = [own_target_encountered, observation = [own_target_encountered,
other_target_encountered, other_target_encountered,
other_agent_encountered, other_agent_encountered,
potential_conflict,
unusable_switch,
np.inf, np.inf,
np.inf, self.distance_map[handle, position[0], position[1], direction],
other_agent_same_direction, other_agent_same_direction,
other_agent_opposite_direction, other_agent_opposite_direction
potential_conflict
] ]
else: else:
observation = [own_target_encountered, observation = [own_target_encountered,
other_target_encountered, other_target_encountered,
other_agent_encountered, other_agent_encountered,
root_observation[3] + num_steps, potential_conflict,
unusable_switch,
tot_dist,
self.distance_map[handle, position[0], position[1], direction], self.distance_map[handle, position[0], position[1], direction],
other_agent_same_direction, other_agent_same_direction,
other_agent_opposite_direction, other_agent_opposite_direction,
potential_conflict
] ]
# ############################# # #############################
# ############################# # #############################
...@@ -531,7 +548,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -531,7 +548,7 @@ class TreeObsForRailEnv(ObservationBuilder):
return observation, visited return observation, visited
def util_print_obs_subtree(self, tree, num_features_per_node=8, prompt='', current_depth=0): def util_print_obs_subtree(self, tree, num_features_per_node=9, prompt='', current_depth=0):
""" """
Utility function to pretty-print tree observations returned by this object. Utility function to pretty-print tree observations returned by this object.
""" """
......
...@@ -38,7 +38,7 @@ class RenderTool(object): ...@@ -38,7 +38,7 @@ class RenderTool(object):
gTheta = np.linspace(0, np.pi / 2, 5) gTheta = np.linspace(0, np.pi / 2, 5)
gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1] gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1]
def __init__(self, env, gl="PILSVG", jupyter=False, agentRenderVariant=AgentRenderVariant.AGENT_SHOWS_OPTIONS): def __init__(self, env, gl="PILSVG", jupyter=False, agentRenderVariant=AgentRenderVariant.ONE_STEP_BEHIND):
self.env = env self.env = env
self.iFrame = 0 self.iFrame = 0
self.time1 = time.time() self.time1 = time.time()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment