Skip to content
Snippets Groups Projects
Commit 8176e2fe authored by Erik Nygren's avatar Erik Nygren
Browse files

updated comment to make tree observation more understandable

parent 30a149a5
No related branches found
No related tags found
No related merge requests found
......@@ -76,6 +76,8 @@ for trials in range(1, n_trials + 1):
for a in range(env.get_num_agents()):
action = agent.act(obs[a])
action_dict.update({a: action})
# Uncomment next line to print observation of an agent
# TreeObservation.util_print_obs_subtree((obs[a]))
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
......
......@@ -28,7 +28,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth + 1):
size += pow4
pow4 *= 4
self.observation_dim = 8
self.observation_dim = 9
self.observation_space = [size * self.observation_dim]
self.location_has_agent = {}
self.location_has_agent_direction = {}
......@@ -223,24 +223,29 @@ class TreeObsForRailEnv(ObservationBuilder):
#3: if another agent is detected the distance in number of cells from current agent position is stored.
#4: This feature stores the distance in number of cells to the next branching store (current node)
#4: possible conflict detected
tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the
distance in number of cells from current agent position
#5: minimum distance from node to the agent's target given the direction of the agent if this path is chosen
0 = No other agent reserve the same cell at similar time
#5: if an not usable switch (for agent) is detected we store the distance.
#6: This feature stores the distance in number of cells to the next branching (current node)
#7: minimum distance from node to the agent's target given the direction of the agent if this path is chosen
#6: agent in the same direction
#8: agent in the same direction
n = number of agents present same direction
(possible future use: number of other agents in the same direction in this branch)
0 = no agent present same direction
#7: agent in the opposite drection
#9: agent in the opposite drection
n = number of agents present other direction than myself (so conflict)
(possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
0 = no agent present other direction than myself
#8: possible conflict detected
1 = Other agent predicts to pass along this cell at the same time as the agent
0 = No other agent reserve the same cell at similar time
Missing/padding nodes are filled in with -inf (truncated).
......@@ -261,7 +266,7 @@ class TreeObsForRailEnv(ObservationBuilder):
num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, 0]
observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
root_observation = observation[:]
visited = set()
......@@ -294,6 +299,8 @@ class TreeObsForRailEnv(ObservationBuilder):
def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth):
"""
Utility function to compute tree-based observations.
We walk along the branch and collect the information documented in the get() function.
If there is a branching point a new node is created and each possible branch is explored.
"""
# [Recursive branch opened]
if depth >= self.max_depth + 1:
......@@ -313,9 +320,11 @@ class TreeObsForRailEnv(ObservationBuilder):
own_target_encountered = np.inf
other_agent_encountered = np.inf
other_target_encountered = np.inf
potential_conflict = np.inf
unusable_switch = np.inf
other_agent_same_direction = 0
other_agent_opposite_direction = 0
potential_conflict = 0
num_steps = 1
while exploring:
# #############################
......@@ -323,8 +332,8 @@ class TreeObsForRailEnv(ObservationBuilder):
# Modify here to compute any useful data required to build the end node's features. This code is called
# for each cell visited between the previous branching node and the next switch / target / dead-end.
if position in self.location_has_agent:
if num_steps < other_agent_encountered:
other_agent_encountered = num_steps
if tot_dist < other_agent_encountered:
other_agent_encountered = tot_dist
if self.location_has_agent_direction[position] == direction:
# Cummulate the number of agents on branch with same direction
......@@ -345,28 +354,28 @@ class TreeObsForRailEnv(ObservationBuilder):
if int_position in np.delete(self.predicted_pos[tot_dist], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[tot_dist], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[tot_dist][ca[0]]:
potential_conflict = 1
if direction != self.predicted_dir[tot_dist][ca[0]] and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for opposing paths at distance num_step-1
elif int_position in np.delete(self.predicted_pos[pre_step], handle):
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[pre_step][ca[0]]:
potential_conflict = 1
if direction != self.predicted_dir[pre_step][ca[0]] and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for opposing paths at distance num_step+1
elif int_position in np.delete(self.predicted_pos[post_step], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[post_step][ca[0]]:
potential_conflict = 1
if direction != self.predicted_dir[post_step][ca[0]] and tot_dist < potential_conflict:
potential_conflict = tot_dist
if position in self.location_has_target and position != agent.target:
if num_steps < other_target_encountered:
other_target_encountered = num_steps
if tot_dist < other_target_encountered:
other_target_encountered = tot_dist
if position == agent.target:
if num_steps < own_target_encountered:
own_target_encountered = num_steps
if tot_dist < own_target_encountered:
own_target_encountered = tot_dist
# #############################
# #############################
......@@ -382,8 +391,13 @@ class TreeObsForRailEnv(ObservationBuilder):
break
cell_transitions = self.env.rail.get_transitions((*position, direction))
total_transitions = bin(self.env.rail.get_transitions(position)).count("1")
num_transitions = np.count_nonzero(cell_transitions)
exploring = False
# Detect Switches that can only be used by other agents.
if total_transitions > 2 > num_transitions:
unusable_switch = tot_dist
if num_transitions == 1:
# Check if dead-end, or if we can go forward along direction
nbits = 0
......@@ -462,32 +476,35 @@ class TreeObsForRailEnv(ObservationBuilder):
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
potential_conflict,
unusable_switch,
tot_dist,
0,
other_agent_same_direction,
other_agent_opposite_direction,
potential_conflict
other_agent_opposite_direction
]
elif last_isTerminal:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
potential_conflict,
unusable_switch,
np.inf,
np.inf,
self.distance_map[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction,
potential_conflict
other_agent_opposite_direction
]
else:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
potential_conflict,
unusable_switch,
tot_dist,
self.distance_map[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction,
potential_conflict
]
# #############################
# #############################
......@@ -531,7 +548,7 @@ class TreeObsForRailEnv(ObservationBuilder):
return observation, visited
def util_print_obs_subtree(self, tree, num_features_per_node=8, prompt='', current_depth=0):
def util_print_obs_subtree(self, tree, num_features_per_node=9, prompt='', current_depth=0):
"""
Utility function to pretty-print tree observations returned by this object.
"""
......
......@@ -38,7 +38,7 @@ class RenderTool(object):
gTheta = np.linspace(0, np.pi / 2, 5)
gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1]
def __init__(self, env, gl="PILSVG", jupyter=False, agentRenderVariant=AgentRenderVariant.AGENT_SHOWS_OPTIONS):
def __init__(self, env, gl="PILSVG", jupyter=False, agentRenderVariant=AgentRenderVariant.ONE_STEP_BEHIND):
self.env = env
self.iFrame = 0
self.time1 = time.time()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment