diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index de9ee2a45ebdeabec5202be4f12593a82b4e20e4..483ae3053a7884f84668c653d6672ce33982b8b7 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -1,8 +1,8 @@ """ Collection of environment-specific ObservationBuilder. """ -import pprint -from typing import Optional, List, Dict, T, Tuple +import collections +from typing import Optional, List, Dict, Tuple import numpy as np @@ -15,6 +15,19 @@ from flatland.utils.ordered_set import OrderedSet class TreeObsForRailEnv(ObservationBuilder): + + Node = collections.namedtuple('Node', 'dist_1 ' + 'dist_2 ' + 'dist_3 ' + 'dist_4 ' + 'dist_5 ' + 'dist_6 ' + 'dist_7 ' + 'num_agents_8 ' + 'num_agents_9 ' + 'num_agents_10 ' + 'speed_11 ' + 'childs') """ TreeObsForRailEnv object. @@ -165,11 +178,15 @@ class TreeObsForRailEnv(ObservationBuilder): possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) num_transitions = np.count_nonzero(possible_transitions) - # Root node - current position # Here information about the agent itself is stored distance_map = self.env.distance_map.get() - observation = [0, 0, 0, 0, 0, 0, distance_map[(handle, *agent.position, agent.direction)], 0, 0, - agent.malfunction_data['malfunction'], agent.speed_data['speed']] + + root_node_observation = TreeObsForRailEnv.Node(dist_1=0, dist_2=0, dist_3=0, dist_4=0, dist_5=0, dist_6=0, + dist_7=distance_map[(handle, *agent.position, agent.direction)], + num_agents_8=0, num_agents_9=0, + num_agents_10=agent.malfunction_data['malfunction'], + speed_11=agent.speed_data['speed'], + childs={}) visited = OrderedSet() @@ -181,19 +198,22 @@ class TreeObsForRailEnv(ObservationBuilder): if num_transitions == 1: orientation = np.argmax(possible_transitions) - for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: + for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]): + if 
possible_transitions[branch_direction]: new_cell = get_new_position(agent.position, branch_direction) + branch_observation, branch_visited = \ self._explore_branch(handle, new_cell, branch_direction, 1, 1) - observation = observation + branch_observation + root_node_observation.childs[self.tree_explorted_actions_char[i]] = branch_observation + visited |= branch_visited else: # add cells filled with infinity if no transition is possible - observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth) + root_node_observation.childs[self.tree_explorted_actions_char[i]] = -np.inf self.env.dev_obs_dict[handle] = visited - return observation + return root_node_observation def _num_cells_to_fill_in(self, remaining_depth): """Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim.""" @@ -378,53 +398,44 @@ class TreeObsForRailEnv(ObservationBuilder): # Modify here to append new / different features for each visited cell! if last_is_target: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - tot_dist, - 0, - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - min_fractional_speed - ] + node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered, + dist_3=other_agent_encountered, dist_4=potential_conflict, + dist_5=unusable_switch, dist_6=tot_dist, + dist_7=0, + num_agents_8=other_agent_same_direction, + num_agents_9=other_agent_opposite_direction, + num_agents_10=malfunctioning_agent, + speed_11=min_fractional_speed, + childs={}) elif last_is_terminal: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - np.inf, - self.env.distance_map.get()[handle, position[0], position[1], direction], - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - 
min_fractional_speed - ] + node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered, + dist_3=other_agent_encountered, dist_4=potential_conflict, + dist_5=unusable_switch, dist_6=np.inf, + dist_7=self.env.distance_map.get()[handle, position[0], position[1], direction], + num_agents_8=other_agent_same_direction, + num_agents_9=other_agent_opposite_direction, + num_agents_10=malfunctioning_agent, + speed_11=min_fractional_speed, + childs={}) else: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - tot_dist, - self.env.distance_map.get()[handle, position[0], position[1], direction], - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - min_fractional_speed - ] + node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered, + dist_3=other_agent_encountered, dist_4=potential_conflict, + dist_5=unusable_switch, dist_6=tot_dist, + dist_7=self.env.distance_map.get()[handle, position[0], position[1], direction], + num_agents_8=other_agent_same_direction, + num_agents_9=other_agent_opposite_direction, + num_agents_10=malfunctioning_agent, + speed_11=min_fractional_speed, + childs={}) # ############################# # ############################# # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation # Get the possible transitions possible_transitions = self.env.rail.get_transitions(*position, direction) - for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]: + for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]): if last_is_dead_end and self.env.rail.get_transition((*position, direction), (branch_direction + 2) % 4): # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes @@ -435,7 
def util_print_obs_subtree(self, tree: 'TreeObsForRailEnv.Node'):
    """
    Utility function to print tree observations returned by this object.

    The node handed in is printed with the label "root"; every node below it is printed on its own
    line, indented by one additional tab per tree level.

    Parameters
    ----------
    tree
        Observation node (a `Node` named tuple) or the `-np.inf` placeholder that stands for a
        branch without an observation.
    """
    # Delegate to the recursive printer: for a fully populated tree the output is unchanged, and a
    # node whose `childs` dict is empty (it is cleared once the maximum depth is reached) or a bare
    # `-np.inf` placeholder no longer raises.
    self.print_subtree(tree, "root", "")


def print_node_features(self, node: 'TreeObsForRailEnv.Node', label, indent):
    """
    Print the eleven scalar features of a single node on one line.

    Parameters
    ----------
    node
        Node whose features are printed; its `childs` are not touched.
    label
        Text printed after "Direction " (an action character or "root").
    indent
        Prefix printed in front of the line.
    """
    feature_names = ("dist_1", "dist_2", "dist_3", "dist_4", "dist_5", "dist_6", "dist_7",
                     "num_agents_8", "num_agents_9", "num_agents_10", "speed_11")

    # Assemble the positional arguments for a single print() call so that the default separator
    # produces exactly the historical layout: value ", " value ", " ...
    pieces = [indent, "Direction ", label, ": "]
    for position, name in enumerate(feature_names):
        if position:
            pieces.append(", ")
        pieces.append(getattr(node, name))
    print(*pieces)


def print_subtree(self, node, label, indent):
    """
    Recursively print `node` and everything below it.

    Parameters
    ----------
    node
        `Node` named tuple, or `-np.inf` (or any falsy value) for a branch without an observation.
    label
        Text printed after "Direction " for this node.
    indent
        Prefix for this node's line; children get one more tab.
    """
    if node == -np.inf or not node:
        print(indent, "Direction ", label, ": -np.inf")
        return

    self.print_node_features(node, label, indent)

    children = node.childs
    if not children:
        # Leaf: nothing below this node was recorded.
        return

    for direction in self.tree_explorted_actions_char:
        # A direction missing from the dict is reported like an unexplorable branch instead of
        # aborting the whole print with a KeyError.
        self.print_subtree(children.get(direction, -np.inf), direction, indent + "\t")
def _check_expected_conflicts(expected_conflicts, obs_builder, tree: 'TreeObsForRailEnv.Node', prompt=''):
    """
    Assert that the `num_agents_9` feature is positive exactly on the expected tree paths.

    The root and the first two levels below it are inspected. A path is identified by the tuple of
    action characters leading to the node: `()` for the root, `(a_1,)` for a direct child and
    `(a_1, a_2)` for a grandchild. For backward compatibility a direct child may also be listed as
    the bare action character `a_1`.

    Parameters
    ----------
    expected_conflicts
        Collection of paths on which `num_agents_9 > 0` is expected.
    obs_builder
        Object providing `tree_explorted_actions_char`, the action characters used as child keys.
    tree
        Root node of the observation tree; children are `Node` tuples or `-np.inf` placeholders.
    prompt
        Prefix for the assertion messages.

    Raises
    ------
    AssertionError
        If a node's `num_agents_9` does not match the expectation; the message names the path.
    """
    actions = obs_builder.tree_explorted_actions_char

    assert (tree.num_agents_9 > 0) == (() in expected_conflicts), "{}[]".format(prompt)

    for a_1 in actions:
        # `(a_1)` is just the string itself, so look for the real 1-tuple as well.
        expected_1 = a_1 in expected_conflicts or (a_1,) in expected_conflicts

        # A missing key is treated like the `-np.inf` placeholder: nothing was observed there.
        child = tree.childs.get(a_1, -np.inf)
        if child == -np.inf:
            assert not expected_1, "{}[{}]".format(prompt, a_1)
            continue

        assert (child.num_agents_9 > 0) == expected_1, "{}[{}]".format(prompt, a_1)

        for a_2 in actions:
            expected_2 = (a_1, a_2) in expected_conflicts
            grandchild = child.childs.get(a_2, -np.inf)
            if grandchild == -np.inf:
                assert not expected_2, "{}[{}][{}]".format(prompt, a_1, a_2)
            else:
                assert (grandchild.num_agents_9 > 0) == expected_2, "{}[{}][{}]".format(prompt, a_1, a_2)