Commit 2e6c222d authored by u229589's avatar u229589
Browse files

add namedtuple for node in TreeObsForRailEnv

parent 60da3eb9
"""
Collection of environment-specific ObservationBuilder.
"""
import pprint
from typing import Optional, List, Dict, T, Tuple
import collections
from typing import Optional, List, Dict, Tuple
import numpy as np
......@@ -15,6 +15,19 @@ from flatland.utils.ordered_set import OrderedSet
class TreeObsForRailEnv(ObservationBuilder):
# One node of the observation tree. The trailing comment on each field names
# the local value that branch exploration stores in it.
Node = collections.namedtuple('Node', [
    'dist_1',         # own_target_encountered
    'dist_2',         # other_target_encountered
    'dist_3',         # other_agent_encountered
    'dist_4',         # potential_conflict
    'dist_5',         # unusable_switch
    'dist_6',         # tot_dist (np.inf when the branch ends in a terminal cell)
    'dist_7',         # distance-map value for the cell (0 when the branch ends on the target)
    'num_agents_8',   # other_agent_same_direction
    'num_agents_9',   # other_agent_opposite_direction
    'num_agents_10',  # malfunctioning_agent (root node: the agent's own malfunction value)
    'speed_11',       # min_fractional_speed (root node: the agent's own speed)
    'childs',         # dict: action character -> child Node, or -np.inf if no branch
])
"""
TreeObsForRailEnv object.
......@@ -165,11 +178,15 @@ class TreeObsForRailEnv(ObservationBuilder):
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position
# Here information about the agent itself is stored
distance_map = self.env.distance_map.get()
observation = [0, 0, 0, 0, 0, 0, distance_map[(handle, *agent.position, agent.direction)], 0, 0,
agent.malfunction_data['malfunction'], agent.speed_data['speed']]
root_node_observation = TreeObsForRailEnv.Node(dist_1=0, dist_2=0, dist_3=0, dist_4=0, dist_5=0, dist_6=0,
dist_7=distance_map[(handle, *agent.position, agent.direction)],
num_agents_8=0, num_agents_9=0,
num_agents_10=agent.malfunction_data['malfunction'],
speed_11=agent.speed_data['speed'],
childs={})
visited = OrderedSet()
......@@ -181,19 +198,22 @@ class TreeObsForRailEnv(ObservationBuilder):
if num_transitions == 1:
orientation = np.argmax(possible_transitions)
for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):
if possible_transitions[branch_direction]:
new_cell = get_new_position(agent.position, branch_direction)
branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, 1, 1)
observation = observation + branch_observation
root_node_observation.childs[self.tree_explorted_actions_char[i]] = branch_observation
visited |= branch_visited
else:
# add cells filled with infinity if no transition is possible
observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth)
root_node_observation.childs[self.tree_explorted_actions_char[i]] = -np.inf
self.env.dev_obs_dict[handle] = visited
return observation
return root_node_observation
def _num_cells_to_fill_in(self, remaining_depth):
"""Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim."""
......@@ -378,53 +398,44 @@ class TreeObsForRailEnv(ObservationBuilder):
# Modify here to append new / different features for each visited cell!
if last_is_target:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
potential_conflict,
unusable_switch,
tot_dist,
0,
other_agent_same_direction,
other_agent_opposite_direction,
malfunctioning_agent,
min_fractional_speed
]
node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
dist_3=other_agent_encountered, dist_4=potential_conflict,
dist_5=unusable_switch, dist_6=tot_dist,
dist_7=0,
num_agents_8=other_agent_same_direction,
num_agents_9=other_agent_opposite_direction,
num_agents_10=malfunctioning_agent,
speed_11=min_fractional_speed,
childs={})
elif last_is_terminal:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
potential_conflict,
unusable_switch,
np.inf,
self.env.distance_map.get()[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction,
malfunctioning_agent,
min_fractional_speed
]
node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
dist_3=other_agent_encountered, dist_4=potential_conflict,
dist_5=unusable_switch, dist_6=np.inf,
dist_7=self.env.distance_map.get()[handle, position[0], position[1], direction],
num_agents_8=other_agent_same_direction,
num_agents_9=other_agent_opposite_direction,
num_agents_10=malfunctioning_agent,
speed_11=min_fractional_speed,
childs={})
else:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
potential_conflict,
unusable_switch,
tot_dist,
self.env.distance_map.get()[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction,
malfunctioning_agent,
min_fractional_speed
]
node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered,
dist_3=other_agent_encountered, dist_4=potential_conflict,
dist_5=unusable_switch, dist_6=tot_dist,
dist_7=self.env.distance_map.get()[handle, position[0], position[1], direction],
num_agents_8=other_agent_same_direction,
num_agents_9=other_agent_opposite_direction,
num_agents_10=malfunctioning_agent,
speed_11=min_fractional_speed,
childs={})
# #############################
# #############################
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
# Get the possible transitions
possible_transitions = self.env.rail.get_transitions(*position, direction)
for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]):
if last_is_dead_end and self.env.rail.get_transition((*position, direction),
(branch_direction + 2) % 4):
# Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
......@@ -435,7 +446,7 @@ class TreeObsForRailEnv(ObservationBuilder):
(branch_direction + 2) % 4,
tot_dist + 1,
depth + 1)
observation = observation + branch_observation
node.childs[self.tree_explorted_actions_char[i]] = branch_observation
if len(branch_visited) != 0:
visited |= branch_visited
elif last_is_switch and possible_transitions[branch_direction]:
......@@ -445,21 +456,43 @@ class TreeObsForRailEnv(ObservationBuilder):
branch_direction,
tot_dist + 1,
depth + 1)
observation = observation + branch_observation
node.childs[self.tree_explorted_actions_char[i]] = branch_observation
if len(branch_visited) != 0:
visited |= branch_visited
else:
# no exploring possible, add just cells with infinity
observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth)
node.childs[self.tree_explorted_actions_char[i]] = -np.inf
return observation, visited
if depth == self.max_depth:
node.childs.clear()
return node, visited
def util_print_obs_subtree(self, tree):
def util_print_obs_subtree(self, tree: Node):
"""
Utility function to pretty-print tree observations returned by this object.
Utility function to print tree observations returned by this object.
"""
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(self.unfold_observation_tree(tree))
self.print_node_features(tree, "root", "")
for direction in self.tree_explorted_actions_char:
self.print_subtree(tree.childs[direction], direction, "\t")
def print_node_features(self, node: Node, label, indent):
    """
    Print the eleven feature values of a single tree node on one line.

    The line starts with `indent`, then "Direction ", `label` and ": ",
    followed by the features in field order separated by ", ".
    The `childs` field is not printed.
    """
    features = (node.dist_1, node.dist_2, node.dist_3, node.dist_4,
                node.dist_5, node.dist_6, node.dist_7,
                node.num_agents_8, node.num_agents_9, node.num_agents_10,
                node.speed_11)
    # Build the exact argument sequence for print(): every piece is still a
    # separate argument, so print's default " " separator is applied as before.
    pieces = [indent, "Direction ", label, ": ", features[0]]
    for value in features[1:]:
        pieces.append(", ")
        pieces.append(value)
    print(*pieces)
def print_subtree(self, node, label, indent):
    """
    Recursively print the subtree rooted at `node`.

    A branch that holds -np.inf (or any falsy value) instead of a node is
    printed as a single "-np.inf" line. Otherwise the node's features are
    printed and, if it has children, each child is printed one tab deeper,
    in the order given by `self.tree_explorted_actions_char`.
    """
    is_placeholder = node == -np.inf or not node
    if is_placeholder:
        print(indent, "Direction ", label, ": -np.inf")
        return

    self.print_node_features(node, label, indent)

    children = node.childs
    if children:
        deeper_indent = indent + "\t"
        for action in self.tree_explorted_actions_char:
            self.print_subtree(children[action], action, deeper_indent)
def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True):
"""
......
......@@ -264,9 +264,10 @@ def test_shortest_path_predictor_conflicts(rendering=False):
# get the trees to test
obs_builder: TreeObsForRailEnv = env.obs_builder
pp = pprint.PrettyPrinter(indent=4)
tree_0 = obs_builder.unfold_observation_tree(observations[0])
tree_1 = obs_builder.unfold_observation_tree(observations[1])
pp.pprint(tree_0)
tree_0 = observations[0]
tree_1 = observations[1]
env.obs_builder.util_print_obs_subtree(tree_0)
env.obs_builder.util_print_obs_subtree(tree_1)
# check the expectations
expected_conflicts_0 = [('F', 'R')]
......@@ -275,11 +276,18 @@ def test_shortest_path_predictor_conflicts(rendering=False):
_check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ")
def _check_expected_conflicts(expected_conflicts, obs_builder, tree_0, prompt=''):
assert (tree_0[''][8] > 0) == (() in expected_conflicts), "{}[]".format(prompt)
def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''):
assert (tree.num_agents_9 > 0) == (() in expected_conflicts), "{}[]".format(prompt)
for a_1 in obs_builder.tree_explorted_actions_char:
conflict = tree_0[a_1][''][8]
assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
if tree.childs[a_1] == -np.inf:
assert False == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
continue
else:
conflict = tree.childs[a_1].num_agents_9
assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
for a_2 in obs_builder.tree_explorted_actions_char:
conflict = tree_0[a_1][a_2][''][8]
assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
if tree.childs[a_1].childs[a_2] == -np.inf:
assert False == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
else:
conflict = tree.childs[a_1].childs[a_2].num_agents_9
assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment