Skip to content
Snippets Groups Projects
Commit 2e6c222d authored by u229589's avatar u229589
Browse files

add namedtuple for node in TreeObsForRailEnv

parent 60da3eb9
No related branches found
No related tags found
No related merge requests found
"""
Collection of environment-specific ObservationBuilder.
"""
import pprint
from typing import Optional, List, Dict, T, Tuple
import collections
from typing import Optional, List, Dict, Tuple
import numpy as np
......@@ -15,6 +15,19 @@ from flatland.utils.ordered_set import OrderedSet
class TreeObsForRailEnv(ObservationBuilder):
Node = collections.namedtuple('Node', 'dist_1 '
'dist_2 '
'dist_3 '
'dist_4 '
'dist_5 '
'dist_6 '
'dist_7 '
'num_agents_8 '
'num_agents_9 '
'num_agents_10 '
'speed_11 '
'childs')
"""
TreeObsForRailEnv object.
......@@ -165,11 +178,15 @@ class TreeObsForRailEnv(ObservationBuilder):
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position
# Here information about the agent itself is stored
distance_map = self.env.distance_map.get()
observation = [0, 0, 0, 0, 0, 0, distance_map[(handle, *agent.position, agent.direction)], 0, 0,
agent.malfunction_data['malfunction'], agent.speed_data['speed']]
root_node_observation = TreeObsForRailEnv.Node(dist_1=0, dist_2=0, dist_3=0, dist_4=0, dist_5=0, dist_6=0,
dist_7=distance_map[(handle, *agent.position, agent.direction)],
num_agents_8=0, num_agents_9=0,
num_agents_10=agent.malfunction_data['malfunction'],
speed_11=agent.speed_data['speed'],
childs={})
visited = OrderedSet()
......@@ -181,19 +198,22 @@ class TreeObsForRailEnv(ObservationBuilder):
if num_transitions == 1:
orientation = np.argmax(possible_transitions)
for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):
if possible_transitions[branch_direction]:
new_cell = get_new_position(agent.position, branch_direction)
branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, 1, 1)
observation = observation + branch_observation
root_node_observation.childs[self.tree_explorted_actions_char[i]] = branch_observation
visited |= branch_visited
else:
# add cells filled with infinity if no transition is possible
observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth)
root_node_observation.childs[self.tree_explorted_actions_char[i]] = -np.inf
self.env.dev_obs_dict[handle] = visited
return observation
return root_node_observation
def _num_cells_to_fill_in(self, remaining_depth):
"""Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim."""
......@@ -378,53 +398,44 @@ class TreeObsForRailEnv(ObservationBuilder):
# Modify here to append new / different features for each visited cell!
if last_is_target:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
potential_conflict,
unusable_switch,
tot_dist,
0,
other_agent_same_direction,
other_agent_opposite_direction,
malfunctioning_agent,
min_fractional_speed
]
node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
dist_3=other_agent_encountered, dist_4=potential_conflict,
dist_5=unusable_switch, dist_6=tot_dist,
dist_7=0,
num_agents_8=other_agent_same_direction,
num_agents_9=other_agent_opposite_direction,
num_agents_10=malfunctioning_agent,
speed_11=min_fractional_speed,
childs={})
elif last_is_terminal:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
potential_conflict,
unusable_switch,
np.inf,
self.env.distance_map.get()[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction,
malfunctioning_agent,
min_fractional_speed
]
node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
dist_3=other_agent_encountered, dist_4=potential_conflict,
dist_5=unusable_switch, dist_6=np.inf,
dist_7=self.env.distance_map.get()[handle, position[0], position[1], direction],
num_agents_8=other_agent_same_direction,
num_agents_9=other_agent_opposite_direction,
num_agents_10=malfunctioning_agent,
speed_11=min_fractional_speed,
childs={})
else:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
potential_conflict,
unusable_switch,
tot_dist,
self.env.distance_map.get()[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction,
malfunctioning_agent,
min_fractional_speed
]
node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
dist_3=other_agent_encountered, dist_4=potential_conflict,
dist_5=unusable_switch, dist_6=tot_dist,
dist_7=self.env.distance_map.get()[handle, position[0], position[1], direction],
num_agents_8=other_agent_same_direction,
num_agents_9=other_agent_opposite_direction,
num_agents_10=malfunctioning_agent,
speed_11=min_fractional_speed,
childs={})
# #############################
# #############################
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
# Get the possible transitions
possible_transitions = self.env.rail.get_transitions(*position, direction)
for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]):
if last_is_dead_end and self.env.rail.get_transition((*position, direction),
(branch_direction + 2) % 4):
# Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
......@@ -435,7 +446,7 @@ class TreeObsForRailEnv(ObservationBuilder):
(branch_direction + 2) % 4,
tot_dist + 1,
depth + 1)
observation = observation + branch_observation
node.childs[self.tree_explorted_actions_char[i]] = branch_observation
if len(branch_visited) != 0:
visited |= branch_visited
elif last_is_switch and possible_transitions[branch_direction]:
......@@ -445,21 +456,43 @@ class TreeObsForRailEnv(ObservationBuilder):
branch_direction,
tot_dist + 1,
depth + 1)
observation = observation + branch_observation
node.childs[self.tree_explorted_actions_char[i]] = branch_observation
if len(branch_visited) != 0:
visited |= branch_visited
else:
# no exploring possible, add just cells with infinity
observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth)
node.childs[self.tree_explorted_actions_char[i]] = -np.inf
return observation, visited
if depth == self.max_depth:
node.childs.clear()
return node, visited
def util_print_obs_subtree(self, tree):
def util_print_obs_subtree(self, tree: Node):
"""
Utility function to pretty-print tree observations returned by this object.
Utility function to print tree observations returned by this object.
"""
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(self.unfold_observation_tree(tree))
self.print_node_features(tree, "root", "")
for direction in self.tree_explorted_actions_char:
self.print_subtree(tree.childs[direction], direction, "\t")
def print_node_features(self, node: Node, label, indent):
print(indent, "Direction ", label, ": ", node.dist_1, ", ", node.dist_2, ", ", node.dist_3, ", ", node.dist_4,
", ", node.dist_5, ", ", node.dist_6, ", ", node.dist_7, ", ", node.num_agents_8, ", ", node.num_agents_9,
", ", node.num_agents_10, ", ", node.speed_11)
def print_subtree(self, node, label, indent):
if node == -np.inf or not node:
print(indent, "Direction ", label, ": -np.inf")
return
self.print_node_features(node, label, indent)
if not node.childs:
return
for direction in self.tree_explorted_actions_char:
self.print_subtree(node.childs[direction], direction, indent + "\t")
def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True):
"""
......
......@@ -264,9 +264,10 @@ def test_shortest_path_predictor_conflicts(rendering=False):
# get the trees to test
obs_builder: TreeObsForRailEnv = env.obs_builder
pp = pprint.PrettyPrinter(indent=4)
tree_0 = obs_builder.unfold_observation_tree(observations[0])
tree_1 = obs_builder.unfold_observation_tree(observations[1])
pp.pprint(tree_0)
tree_0 = observations[0]
tree_1 = observations[1]
env.obs_builder.util_print_obs_subtree(tree_0)
env.obs_builder.util_print_obs_subtree(tree_1)
# check the expectations
expected_conflicts_0 = [('F', 'R')]
......@@ -275,11 +276,18 @@ def test_shortest_path_predictor_conflicts(rendering=False):
_check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ")
def _check_expected_conflicts(expected_conflicts, obs_builder, tree_0, prompt=''):
assert (tree_0[''][8] > 0) == (() in expected_conflicts), "{}[]".format(prompt)
def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''):
assert (tree.num_agents_9 > 0) == (() in expected_conflicts), "{}[]".format(prompt)
for a_1 in obs_builder.tree_explorted_actions_char:
conflict = tree_0[a_1][''][8]
assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
if tree.childs[a_1] == -np.inf:
assert False == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
continue
else:
conflict = tree.childs[a_1].num_agents_9
assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
for a_2 in obs_builder.tree_explorted_actions_char:
conflict = tree_0[a_1][a_2][''][8]
assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
if tree.childs[a_1].childs[a_2] == -np.inf:
assert False == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
else:
conflict = tree.childs[a_1].childs[a_2].num_agents_9
assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment