From 6dea1dff5d4ec6826dfcee57f74b8444c229e726 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Fri, 7 Jun 2019 10:34:57 +0200 Subject: [PATCH] Update split_tree in observations.py to include the agent direction and normalize the state correctly --- flatland/envs/observations.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index e6b7a4f0..5c4ed86b 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -468,7 +468,7 @@ class TreeObsForRailEnv(ObservationBuilder): prompt=prompt_[children], current_depth=current_depth + 1) - def split_tree(self, tree, num_features_per_node=5, current_depth=0): + def split_tree(self, tree, num_features_per_node=7, current_depth=0): """ :param tree: @@ -490,17 +490,19 @@ class TreeObsForRailEnv(ObservationBuilder): pow4 *= 4 child_size = (len(tree) - num_features_per_node) // 4 tree_data = tree[0:num_features_per_node - 1].tolist() - distance_data = [tree[num_features_per_node - 1]] + distance_data = [tree[num_features_per_node - 3]] + agent_data = tree[-2:] for children in range(4): child_tree = tree[(num_features_per_node + children * child_size): (num_features_per_node + (children + 1) * child_size)] - tmp_tree_data, tmp_distance_data = self.split_tree(child_tree, + tmp_tree_data, tmp_distance_data, tmp_agent_data = self.split_tree(child_tree, num_features_per_node, current_depth=current_depth + 1) if len(tmp_tree_data) > 0: tree_data.extend(tmp_tree_data) distance_data.extend(tmp_distance_data) - return tree_data, distance_data + agent_data.extrend(tmp_agent_data) + return tree_data, distance_data, agent_data class GlobalObsForRailEnv(ObservationBuilder): -- GitLab