Skip to content
Snippets Groups Projects
Commit 1c652d1c authored by u229589's avatar u229589
Browse files

rename features in namedtuple Node in TreeObsForRailEnv

parent 2e6c222d
No related branches found
No related tags found
No related merge requests found
......@@ -16,17 +16,17 @@ from flatland.utils.ordered_set import OrderedSet
class TreeObsForRailEnv(ObservationBuilder):
Node = collections.namedtuple('Node', 'dist_1 '
'dist_2 '
'dist_3 '
'dist_4 '
'dist_5 '
'dist_6 '
'dist_7 '
'num_agents_8 '
'num_agents_9 '
'num_agents_10 '
'speed_11 '
Node = collections.namedtuple('Node', 'dist_own_target_encountered '
'dist_other_target_encountered '
'dist_other_agent_encountered '
'dist_potential_conflict '
'dist_unusable_switch '
'dist_to_next_branch '
'dist_min_to_target '
'num_agents_same_direction '
'num_agents_opposite_direction '
'num_agents_malfunctioning '
'speed_min_fractional '
'childs')
"""
TreeObsForRailEnv object.
......@@ -53,7 +53,6 @@ class TreeObsForRailEnv(ObservationBuilder):
self.location_has_agent_direction = {}
self.predictor = predictor
self.location_has_target = None
self.tree_explored_actions = [1, 2, 3, 0]
self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
def reset(self):
......@@ -181,12 +180,15 @@ class TreeObsForRailEnv(ObservationBuilder):
# Here information about the agent itself is stored
distance_map = self.env.distance_map.get()
root_node_observation = TreeObsForRailEnv.Node(dist_1=0, dist_2=0, dist_3=0, dist_4=0, dist_5=0, dist_6=0,
dist_7=distance_map[(handle, *agent.position, agent.direction)],
num_agents_8=0, num_agents_9=0,
num_agents_10=agent.malfunction_data['malfunction'],
speed_11=agent.speed_data['speed'],
childs={})
root_node_observation = TreeObsForRailEnv.Node(dist_own_target_encountered=0, dist_other_target_encountered=0,
dist_other_agent_encountered=0, dist_potential_conflict=0,
dist_unusable_switch=0, dist_to_next_branch=0,
dist_min_to_target=distance_map[(handle, *agent.position,
agent.direction)],
num_agents_same_direction=0, num_agents_opposite_direction=0,
num_agents_malfunctioning=agent.malfunction_data['malfunction'],
speed_min_fractional=agent.speed_data['speed'],
childs={})
visited = OrderedSet()
......@@ -215,15 +217,6 @@ class TreeObsForRailEnv(ObservationBuilder):
return root_node_observation
def _num_cells_to_fill_in(self, remaining_depth):
"""Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim."""
num_observations = 0
pow4 = 1
for i in range(remaining_depth):
num_observations += pow4
pow4 *= 4
return num_observations * self.observation_dim
def _explore_branch(self, handle, position, direction, tot_dist, depth):
"""
Utility function to compute tree-based observations.
......@@ -398,37 +391,28 @@ class TreeObsForRailEnv(ObservationBuilder):
# Modify here to append new / different features for each visited cell!
if last_is_target:
node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
dist_3=other_agent_encountered, dist_4=potential_conflict,
dist_5=unusable_switch, dist_6=tot_dist,
dist_7=0,
num_agents_8=other_agent_same_direction,
num_agents_9=other_agent_opposite_direction,
num_agents_10=malfunctioning_agent,
speed_11=min_fractional_speed,
childs={})
dist_to_next_branch = tot_dist,
dist_min_to_target = 0,
elif last_is_terminal:
node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
dist_3=other_agent_encountered, dist_4=potential_conflict,
dist_5=unusable_switch, dist_6=np.inf,
dist_7=self.env.distance_map.get()[handle, position[0], position[1], direction],
num_agents_8=other_agent_same_direction,
num_agents_9=other_agent_opposite_direction,
num_agents_10=malfunctioning_agent,
speed_11=min_fractional_speed,
childs={})
dist_to_next_branch = np.inf,
dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction],
else:
node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
dist_3=other_agent_encountered, dist_4=potential_conflict,
dist_5=unusable_switch, dist_6=tot_dist,
dist_7=self.env.distance_map.get()[handle, position[0], position[1], direction],
num_agents_8=other_agent_same_direction,
num_agents_9=other_agent_opposite_direction,
num_agents_10=malfunctioning_agent,
speed_11=min_fractional_speed,
childs={})
dist_to_next_branch = tot_dist,
dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction],
node = TreeObsForRailEnv.Node(dist_own_target_encountered=own_target_encountered,
dist_other_target_encountered=other_target_encountered,
dist_other_agent_encountered=other_agent_encountered,
dist_potential_conflict=potential_conflict,
dist_unusable_switch=unusable_switch,
dist_to_next_branch=dist_to_next_branch,
dist_min_to_target=dist_min_to_target,
num_agents_same_direction=other_agent_same_direction,
num_agents_opposite_direction=other_agent_opposite_direction,
num_agents_malfunctioning=malfunctioning_agent,
speed_min_fractional=min_fractional_speed,
childs={})
# #############################
# #############################
# Start from the current orientation, and see which transitions are available;
......@@ -475,10 +459,13 @@ class TreeObsForRailEnv(ObservationBuilder):
for direction in self.tree_explorted_actions_char:
self.print_subtree(tree.childs[direction], direction, "\t")
def print_node_features(self, node: Node, label, indent):
print(indent, "Direction ", label, ": ", node.dist_1, ", ", node.dist_2, ", ", node.dist_3, ", ", node.dist_4,
", ", node.dist_5, ", ", node.dist_6, ", ", node.dist_7, ", ", node.num_agents_8, ", ", node.num_agents_9,
", ", node.num_agents_10, ", ", node.speed_11)
@staticmethod
def print_node_features(node: Node, label, indent):
print(indent, "Direction ", label, ": ", node.dist_own_target_encountered, ", ",
node.dist_other_target_encountered, ", ", node.dist_other_agent_encountered, ", ",
node.dist_potential_conflict, ", ", node.dist_unusable_switch, ", ", node.dist_to_next_branch, ", ",
node.dist_min_to_target, ", ", node.num_agents_same_direction, ", ", node.num_agents_opposite_direction,
", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional)
def print_subtree(self, node, label, indent):
if node == -np.inf or not node:
......@@ -493,37 +480,6 @@ class TreeObsForRailEnv(ObservationBuilder):
for direction in self.tree_explorted_actions_char:
self.print_subtree(node.childs[direction], direction, indent + "\t")
def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True):
"""
Utility function to pretty-print tree observations returned by this object.
"""
if len(tree) < self.observation_dim:
return
depth = 0
tmp = len(tree) / self.observation_dim - 1
pow4 = 4
while tmp > 0:
tmp -= pow4
depth += 1
pow4 *= 4
unfolded = {}
unfolded[''] = tree[0:self.observation_dim]
child_size = (len(tree) - self.observation_dim) // 4
for child in range(4):
child_tree = tree[(self.observation_dim + child * child_size):
(self.observation_dim + (child + 1) * child_size)]
observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1)
if observation_tree is not None:
if actions_for_display:
label = self.tree_explorted_actions_char[child]
else:
label = self.tree_explored_actions[child]
unfolded[label] = observation_tree
return unfolded
def set_env(self, env: Environment):
super().set_env(env)
if self.predictor:
......
......@@ -277,17 +277,17 @@ def test_shortest_path_predictor_conflicts(rendering=False):
def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''):
assert (tree.num_agents_9 > 0) == (() in expected_conflicts), "{}[]".format(prompt)
assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt)
for a_1 in obs_builder.tree_explorted_actions_char:
if tree.childs[a_1] == -np.inf:
assert False == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
continue
else:
conflict = tree.childs[a_1].num_agents_9
conflict = tree.childs[a_1].num_agents_opposite_direction
assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
for a_2 in obs_builder.tree_explorted_actions_char:
if tree.childs[a_1].childs[a_2] == -np.inf:
assert False == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
else:
conflict = tree.childs[a_1].childs[a_2].num_agents_9
conflict = tree.childs[a_1].childs[a_2].num_agents_opposite_direction
assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment