Commit 1c652d1c authored by u229589's avatar u229589
Browse files

rename features in namedtuple Node in TreeObsForRailEnv

parent 2e6c222d
......@@ -16,17 +16,17 @@ from flatland.utils.ordered_set import OrderedSet
class TreeObsForRailEnv(ObservationBuilder):
Node = collections.namedtuple('Node', 'dist_1 '
'dist_2 '
'dist_3 '
'dist_4 '
'dist_5 '
'dist_6 '
'dist_7 '
'num_agents_8 '
'num_agents_9 '
'num_agents_10 '
'speed_11 '
Node = collections.namedtuple('Node', 'dist_own_target_encountered '
'dist_other_target_encountered '
'dist_other_agent_encountered '
'dist_potential_conflict '
'dist_unusable_switch '
'dist_to_next_branch '
'dist_min_to_target '
'num_agents_same_direction '
'num_agents_opposite_direction '
'num_agents_malfunctioning '
'speed_min_fractional '
'childs')
"""
TreeObsForRailEnv object.
......@@ -53,7 +53,6 @@ class TreeObsForRailEnv(ObservationBuilder):
self.location_has_agent_direction = {}
self.predictor = predictor
self.location_has_target = None
self.tree_explored_actions = [1, 2, 3, 0]
self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
def reset(self):
......@@ -181,12 +180,15 @@ class TreeObsForRailEnv(ObservationBuilder):
# Here information about the agent itself is stored
distance_map = self.env.distance_map.get()
root_node_observation = TreeObsForRailEnv.Node(dist_1=0, dist_2=0, dist_3=0, dist_4=0, dist_5=0, dist_6=0,
dist_7=distance_map[(handle, *agent.position, agent.direction)],
num_agents_8=0, num_agents_9=0,
num_agents_10=agent.malfunction_data['malfunction'],
speed_11=agent.speed_data['speed'],
childs={})
root_node_observation = TreeObsForRailEnv.Node(dist_own_target_encountered=0, dist_other_target_encountered=0,
dist_other_agent_encountered=0, dist_potential_conflict=0,
dist_unusable_switch=0, dist_to_next_branch=0,
dist_min_to_target=distance_map[(handle, *agent.position,
agent.direction)],
num_agents_same_direction=0, num_agents_opposite_direction=0,
num_agents_malfunctioning=agent.malfunction_data['malfunction'],
speed_min_fractional=agent.speed_data['speed'],
childs={})
visited = OrderedSet()
......@@ -215,15 +217,6 @@ class TreeObsForRailEnv(ObservationBuilder):
return root_node_observation
def _num_cells_to_fill_in(self, remaining_depth):
"""Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim."""
num_observations = 0
pow4 = 1
for i in range(remaining_depth):
num_observations += pow4
pow4 *= 4
return num_observations * self.observation_dim
def _explore_branch(self, handle, position, direction, tot_dist, depth):
"""
Utility function to compute tree-based observations.
......@@ -398,37 +391,28 @@ class TreeObsForRailEnv(ObservationBuilder):
# Modify here to append new / different features for each visited cell!
if last_is_target:
node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
dist_3=other_agent_encountered, dist_4=potential_conflict,
dist_5=unusable_switch, dist_6=tot_dist,
dist_7=0,
num_agents_8=other_agent_same_direction,
num_agents_9=other_agent_opposite_direction,
num_agents_10=malfunctioning_agent,
speed_11=min_fractional_speed,
childs={})
dist_to_next_branch = tot_dist,
dist_min_to_target = 0,
elif last_is_terminal:
node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
dist_3=other_agent_encountered, dist_4=potential_conflict,
dist_5=unusable_switch, dist_6=np.inf,
dist_7=self.env.distance_map.get()[handle, position[0], position[1], direction],
num_agents_8=other_agent_same_direction,
num_agents_9=other_agent_opposite_direction,
num_agents_10=malfunctioning_agent,
speed_11=min_fractional_speed,
childs={})
dist_to_next_branch = np.inf,
dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction],
else:
node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
dist_3=other_agent_encountered, dist_4=potential_conflict,
dist_5=unusable_switch, dist_6=tot_dist,
dist_7=self.env.distance_map.get()[handle, position[0], position[1], direction],
num_agents_8=other_agent_same_direction,
num_agents_9=other_agent_opposite_direction,
num_agents_10=malfunctioning_agent,
speed_11=min_fractional_speed,
childs={})
dist_to_next_branch = tot_dist,
dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction],
node = TreeObsForRailEnv.Node(dist_own_target_encountered=own_target_encountered,
dist_other_target_encountered=other_target_encountered,
dist_other_agent_encountered=other_agent_encountered,
dist_potential_conflict=potential_conflict,
dist_unusable_switch=unusable_switch,
dist_to_next_branch=dist_to_next_branch,
dist_min_to_target=dist_min_to_target,
num_agents_same_direction=other_agent_same_direction,
num_agents_opposite_direction=other_agent_opposite_direction,
num_agents_malfunctioning=malfunctioning_agent,
speed_min_fractional=min_fractional_speed,
childs={})
# #############################
# #############################
# Start from the current orientation, and see which transitions are available;
......@@ -475,10 +459,13 @@ class TreeObsForRailEnv(ObservationBuilder):
for direction in self.tree_explorted_actions_char:
self.print_subtree(tree.childs[direction], direction, "\t")
def print_node_features(self, node: Node, label, indent):
print(indent, "Direction ", label, ": ", node.dist_1, ", ", node.dist_2, ", ", node.dist_3, ", ", node.dist_4,
", ", node.dist_5, ", ", node.dist_6, ", ", node.dist_7, ", ", node.num_agents_8, ", ", node.num_agents_9,
", ", node.num_agents_10, ", ", node.speed_11)
@staticmethod
def print_node_features(node: Node, label, indent):
print(indent, "Direction ", label, ": ", node.dist_own_target_encountered, ", ",
node.dist_other_target_encountered, ", ", node.dist_other_agent_encountered, ", ",
node.dist_potential_conflict, ", ", node.dist_unusable_switch, ", ", node.dist_to_next_branch, ", ",
node.dist_min_to_target, ", ", node.num_agents_same_direction, ", ", node.num_agents_opposite_direction,
", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional)
def print_subtree(self, node, label, indent):
if node == -np.inf or not node:
......@@ -493,37 +480,6 @@ class TreeObsForRailEnv(ObservationBuilder):
for direction in self.tree_explorted_actions_char:
self.print_subtree(node.childs[direction], direction, indent + "\t")
def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True):
"""
Utility function to pretty-print tree observations returned by this object.
"""
if len(tree) < self.observation_dim:
return
depth = 0
tmp = len(tree) / self.observation_dim - 1
pow4 = 4
while tmp > 0:
tmp -= pow4
depth += 1
pow4 *= 4
unfolded = {}
unfolded[''] = tree[0:self.observation_dim]
child_size = (len(tree) - self.observation_dim) // 4
for child in range(4):
child_tree = tree[(self.observation_dim + child * child_size):
(self.observation_dim + (child + 1) * child_size)]
observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1)
if observation_tree is not None:
if actions_for_display:
label = self.tree_explorted_actions_char[child]
else:
label = self.tree_explored_actions[child]
unfolded[label] = observation_tree
return unfolded
def set_env(self, env: Environment):
super().set_env(env)
if self.predictor:
......
......@@ -277,17 +277,17 @@ def test_shortest_path_predictor_conflicts(rendering=False):
def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''):
assert (tree.num_agents_9 > 0) == (() in expected_conflicts), "{}[]".format(prompt)
assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt)
for a_1 in obs_builder.tree_explorted_actions_char:
if tree.childs[a_1] == -np.inf:
assert False == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
continue
else:
conflict = tree.childs[a_1].num_agents_9
conflict = tree.childs[a_1].num_agents_opposite_direction
assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
for a_2 in obs_builder.tree_explorted_actions_char:
if tree.childs[a_1].childs[a_2] == -np.inf:
assert False == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
else:
conflict = tree.childs[a_1].childs[a_2].num_agents_9
conflict = tree.childs[a_1].childs[a_2].num_agents_opposite_direction
assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment