diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 483ae3053a7884f84668c653d6672ce33982b8b7..30c0fabeae0d7f9dfd49ba05983ad90f457cd5f2 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -16,17 +16,17 @@ from flatland.utils.ordered_set import OrderedSet class TreeObsForRailEnv(ObservationBuilder): - Node = collections.namedtuple('Node', 'dist_1 ' - 'dist_2 ' - 'dist_3 ' - 'dist_4 ' - 'dist_5 ' - 'dist_6 ' - 'dist_7 ' - 'num_agents_8 ' - 'num_agents_9 ' - 'num_agents_10 ' - 'speed_11 ' + Node = collections.namedtuple('Node', 'dist_own_target_encountered ' + 'dist_other_target_encountered ' + 'dist_other_agent_encountered ' + 'dist_potential_conflict ' + 'dist_unusable_switch ' + 'dist_to_next_branch ' + 'dist_min_to_target ' + 'num_agents_same_direction ' + 'num_agents_opposite_direction ' + 'num_agents_malfunctioning ' + 'speed_min_fractional ' 'childs') """ TreeObsForRailEnv object. @@ -53,7 +53,6 @@ class TreeObsForRailEnv(ObservationBuilder): self.location_has_agent_direction = {} self.predictor = predictor self.location_has_target = None - self.tree_explored_actions = [1, 2, 3, 0] self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] def reset(self): @@ -181,12 +180,15 @@ class TreeObsForRailEnv(ObservationBuilder): # Here information about the agent itself is stored distance_map = self.env.distance_map.get() - root_node_observation = TreeObsForRailEnv.Node(dist_1=0, dist_2=0, dist_3=0, dist_4=0, dist_5=0, dist_6=0, - dist_7=distance_map[(handle, *agent.position, agent.direction)], - num_agents_8=0, num_agents_9=0, - num_agents_10=agent.malfunction_data['malfunction'], - speed_11=agent.speed_data['speed'], - childs={}) + root_node_observation = TreeObsForRailEnv.Node(dist_own_target_encountered=0, dist_other_target_encountered=0, + dist_other_agent_encountered=0, dist_potential_conflict=0, + dist_unusable_switch=0, dist_to_next_branch=0, + dist_min_to_target=distance_map[(handle, *agent.position, + agent.direction)], + num_agents_same_direction=0, num_agents_opposite_direction=0, + num_agents_malfunctioning=agent.malfunction_data['malfunction'], + speed_min_fractional=agent.speed_data['speed'], + childs={}) visited = OrderedSet() @@ -215,15 +217,6 @@ class TreeObsForRailEnv(ObservationBuilder): return root_node_observation - def _num_cells_to_fill_in(self, remaining_depth): - """Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim.""" - num_observations = 0 - pow4 = 1 - for i in range(remaining_depth): - num_observations += pow4 - pow4 *= 4 - return num_observations * self.observation_dim - def _explore_branch(self, handle, position, direction, tot_dist, depth): """ Utility function to compute tree-based observations. @@ -398,37 +391,28 @@ class TreeObsForRailEnv(ObservationBuilder): # Modify here to append new / different features for each visited cell! if last_is_target: - node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered, - dist_3=other_agent_encountered, dist_4=potential_conflict, - dist_5=unusable_switch, dist_6=tot_dist, - dist_7=0, - num_agents_8=other_agent_same_direction, - num_agents_9=other_agent_opposite_direction, - num_agents_10=malfunctioning_agent, - speed_11=min_fractional_speed, - childs={}) - + dist_to_next_branch = tot_dist, + dist_min_to_target = 0, elif last_is_terminal: - node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered, - dist_3=other_agent_encountered, dist_4=potential_conflict, - dist_5=unusable_switch, dist_6=np.inf, - dist_7=self.env.distance_map.get()[handle, position[0], position[1], direction], - num_agents_8=other_agent_same_direction, - num_agents_9=other_agent_opposite_direction, - num_agents_10=malfunctioning_agent, - speed_11=min_fractional_speed, - childs={}) - + dist_to_next_branch = np.inf, + dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction], else: - node = TreeObsForRailEnv.Node(dist_1=own_target_encountered, dist_2=other_target_encountered, - dist_3=other_agent_encountered, dist_4=potential_conflict, - dist_5=unusable_switch, dist_6=tot_dist, - dist_7=self.env.distance_map.get()[handle, position[0], position[1], direction], - num_agents_8=other_agent_same_direction, - num_agents_9=other_agent_opposite_direction, - num_agents_10=malfunctioning_agent, - speed_11=min_fractional_speed, - childs={}) + dist_to_next_branch = tot_dist, + dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction], + + node = TreeObsForRailEnv.Node(dist_own_target_encountered=own_target_encountered, + dist_other_target_encountered=other_target_encountered, + dist_other_agent_encountered=other_agent_encountered, + dist_potential_conflict=potential_conflict, + dist_unusable_switch=unusable_switch, + dist_to_next_branch=dist_to_next_branch, + dist_min_to_target=dist_min_to_target, + num_agents_same_direction=other_agent_same_direction, + num_agents_opposite_direction=other_agent_opposite_direction, + num_agents_malfunctioning=malfunctioning_agent, + speed_min_fractional=min_fractional_speed, + childs={}) + # ############################# # ############################# # Start from the current orientation, and see which transitions are available; @@ -475,10 +459,13 @@ class TreeObsForRailEnv(ObservationBuilder): for direction in self.tree_explorted_actions_char: self.print_subtree(tree.childs[direction], direction, "\t") - def print_node_features(self, node: Node, label, indent): - print(indent, "Direction ", label, ": ", node.dist_1, ", ", node.dist_2, ", ", node.dist_3, ", ", node.dist_4, - ", ", node.dist_5, ", ", node.dist_6, ", ", node.dist_7, ", ", node.num_agents_8, ", ", node.num_agents_9, - ", ", node.num_agents_10, ", ", node.speed_11) + @staticmethod + def print_node_features(node: Node, label, indent): + print(indent, "Direction ", label, ": ", node.dist_own_target_encountered, ", ", + node.dist_other_target_encountered, ", ", node.dist_other_agent_encountered, ", ", + node.dist_potential_conflict, ", ", node.dist_unusable_switch, ", ", node.dist_to_next_branch, ", ", + node.dist_min_to_target, ", ", node.num_agents_same_direction, ", ", node.num_agents_opposite_direction, + ", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional) def print_subtree(self, node, label, indent): if node == -np.inf or not node: @@ -493,37 +480,6 @@ class TreeObsForRailEnv(ObservationBuilder): for direction in self.tree_explorted_actions_char: self.print_subtree(node.childs[direction], direction, indent + "\t") - - def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True): - """ - Utility function to pretty-print tree observations returned by this object. - """ - if len(tree) < self.observation_dim: - return - - depth = 0 - tmp = len(tree) / self.observation_dim - 1 - pow4 = 4 - while tmp > 0: - tmp -= pow4 - depth += 1 - pow4 *= 4 - - unfolded = {} - unfolded[''] = tree[0:self.observation_dim] - child_size = (len(tree) - self.observation_dim) // 4 - for child in range(4): - child_tree = tree[(self.observation_dim + child * child_size): - (self.observation_dim + (child + 1) * child_size)] - observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1) - if observation_tree is not None: - if actions_for_display: - label = self.tree_explorted_actions_char[child] - else: - label = self.tree_explored_actions[child] - unfolded[label] = observation_tree - return unfolded - def set_env(self, env: Environment): super().set_env(env) if self.predictor: diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 7c5e685f3940e0e81859d02dac3752608b5771b7..9ef122e6aa8325d3ec70307332cc7235090270de 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -277,17 +277,17 @@ def test_shortest_path_predictor_conflicts(rendering=False): def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''): - assert (tree.num_agents_9 > 0) == (() in expected_conflicts), "{}[]".format(prompt) + assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt) for a_1 in obs_builder.tree_explorted_actions_char: if tree.childs[a_1] == -np.inf: assert False == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1) continue else: - conflict = tree.childs[a_1].num_agents_9 + conflict = tree.childs[a_1].num_agents_opposite_direction assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1) for a_2 in obs_builder.tree_explorted_actions_char: if tree.childs[a_1].childs[a_2] == -np.inf: assert False == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2) else: - conflict = tree.childs[a_1].childs[a_2].num_agents_9 + conflict = tree.childs[a_1].childs[a_2].num_agents_opposite_direction assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)