Commit 8176e2fe authored by Erik Nygren's avatar Erik Nygren
Browse files

updated comment to make tree observation more understandable

parent 30a149a5
Pipeline #1318 failed with stage
in 6 minutes and 21 seconds
......@@ -76,6 +76,8 @@ for trials in range(1, n_trials + 1):
for a in range(env.get_num_agents()):
action = agent.act(obs[a])
action_dict.update({a: action})
# Uncomment next line to print observation of an agent
# TreeObservation.util_print_obs_subtree((obs[a]))
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
......
......@@ -28,7 +28,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth + 1):
size += pow4
pow4 *= 4
self.observation_dim = 8
self.observation_dim = 9
self.observation_space = [size * self.observation_dim]
self.location_has_agent = {}
self.location_has_agent_direction = {}
......@@ -223,24 +223,29 @@ class TreeObsForRailEnv(ObservationBuilder):
#3: if another agent is detected the distance in number of cells from current agent position is stored.
#4: This feature stores the distance in number of cells to the next branching store (current node)
#4: possible conflict detected
tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the
distance in number of cells from current agent position
#5: minimum distance from node to the agent's target given the direction of the agent if this path is chosen
0 = No other agent reserve the same cell at similar time
#5: if an not usable switch (for agent) is detected we store the distance.
#6: This feature stores the distance in number of cells to the next branching (current node)
#7: minimum distance from node to the agent's target given the direction of the agent if this path is chosen
#6: agent in the same direction
#8: agent in the same direction
n = number of agents present same direction
(possible future use: number of other agents in the same direction in this branch)
0 = no agent present same direction
#7: agent in the opposite drection
#9: agent in the opposite drection
n = number of agents present other direction than myself (so conflict)
(possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
0 = no agent present other direction than myself
#8: possible conflict detected
1 = Other agent predicts to pass along this cell at the same time as the agent
0 = No other agent reserve the same cell at similar time
Missing/padding nodes are filled in with -inf (truncated).
......@@ -261,7 +266,7 @@ class TreeObsForRailEnv(ObservationBuilder):
num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, 0]
observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
root_observation = observation[:]
visited = set()
......@@ -294,6 +299,8 @@ class TreeObsForRailEnv(ObservationBuilder):
def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth):
"""
Utility function to compute tree-based observations.
We walk along the branch and collect the information documented in the get() function.
If there is a branching point a new node is created and each possible branch is explored.
"""
# [Recursive branch opened]
if depth >= self.max_depth + 1:
......@@ -313,9 +320,11 @@ class TreeObsForRailEnv(ObservationBuilder):
own_target_encountered = np.inf
other_agent_encountered = np.inf
other_target_encountered = np.inf
potential_conflict = np.inf
unusable_switch = np.inf
other_agent_same_direction = 0
other_agent_opposite_direction = 0
potential_conflict = 0
num_steps = 1
while exploring:
# #############################
......@@ -323,8 +332,8 @@ class TreeObsForRailEnv(ObservationBuilder):
# Modify here to compute any useful data required to build the end node's features. This code is called
# for each cell visited between the previous branching node and the next switch / target / dead-end.
if position in self.location_has_agent:
if num_steps < other_agent_encountered:
other_agent_encountered = num_steps
if tot_dist < other_agent_encountered:
other_agent_encountered = tot_dist
if self.location_has_agent_direction[position] == direction:
# Cummulate the number of agents on branch with same direction
......@@ -345,28 +354,28 @@ class TreeObsForRailEnv(ObservationBuilder):
if int_position in np.delete(self.predicted_pos[tot_dist], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[tot_dist], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[tot_dist][ca[0]]:
potential_conflict = 1
if direction != self.predicted_dir[tot_dist][ca[0]] and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for opposing paths at distance num_step-1
elif int_position in np.delete(self.predicted_pos[pre_step], handle):
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[pre_step][ca[0]]:
potential_conflict = 1
if direction != self.predicted_dir[pre_step][ca[0]] and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for opposing paths at distance num_step+1
elif int_position in np.delete(self.predicted_pos[post_step], handle):
conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[post_step][ca[0]]:
potential_conflict = 1
if direction != self.predicted_dir[post_step][ca[0]] and tot_dist < potential_conflict:
potential_conflict = tot_dist
if position in self.location_has_target and position != agent.target:
if num_steps < other_target_encountered:
other_target_encountered = num_steps
if tot_dist < other_target_encountered:
other_target_encountered = tot_dist
if position == agent.target:
if num_steps < own_target_encountered:
own_target_encountered = num_steps
if tot_dist < own_target_encountered:
own_target_encountered = tot_dist
# #############################
# #############################
......@@ -382,8 +391,13 @@ class TreeObsForRailEnv(ObservationBuilder):
break
cell_transitions = self.env.rail.get_transitions((*position, direction))
total_transitions = bin(self.env.rail.get_transitions(position)).count("1")
num_transitions = np.count_nonzero(cell_transitions)
exploring = False
# Detect Switches that can only be used by other agents.
if total_transitions > 2 > num_transitions:
unusable_switch = tot_dist
if num_transitions == 1:
# Check if dead-end, or if we can go forward along direction
nbits = 0
......@@ -462,32 +476,35 @@ class TreeObsForRailEnv(ObservationBuilder):
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
potential_conflict,
unusable_switch,
tot_dist,
0,
other_agent_same_direction,
other_agent_opposite_direction,
potential_conflict
other_agent_opposite_direction
]
elif last_isTerminal:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
potential_conflict,
unusable_switch,
np.inf,
np.inf,
self.distance_map[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction,
potential_conflict
other_agent_opposite_direction
]
else:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
potential_conflict,
unusable_switch,
tot_dist,
self.distance_map[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction,
potential_conflict
]
# #############################
# #############################
......@@ -531,7 +548,7 @@ class TreeObsForRailEnv(ObservationBuilder):
return observation, visited
def util_print_obs_subtree(self, tree, num_features_per_node=8, prompt='', current_depth=0):
def util_print_obs_subtree(self, tree, num_features_per_node=9, prompt='', current_depth=0):
"""
Utility function to pretty-print tree observations returned by this object.
"""
......
......@@ -38,7 +38,7 @@ class RenderTool(object):
gTheta = np.linspace(0, np.pi / 2, 5)
gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1]
def __init__(self, env, gl="PILSVG", jupyter=False, agentRenderVariant=AgentRenderVariant.AGENT_SHOWS_OPTIONS):
def __init__(self, env, gl="PILSVG", jupyter=False, agentRenderVariant=AgentRenderVariant.ONE_STEP_BEHIND):
self.env = env
self.iFrame = 0
self.time1 = time.time()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment