Skip to content
Snippets Groups Projects
Commit 6dea1dff authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Update split_tree in observations.py to include the agent direction and...

Update split_tree in observations.py to include the agent direction and normalize the state correctly
parent 2de92fce
No related branches found
No related tags found
No related merge requests found
......@@ -468,7 +468,7 @@ class TreeObsForRailEnv(ObservationBuilder):
prompt=prompt_[children],
current_depth=current_depth + 1)
def split_tree(self, tree, num_features_per_node=5, current_depth=0):
def split_tree(self, tree, num_features_per_node=7, current_depth=0):
"""
:param tree:
......@@ -490,17 +490,19 @@ class TreeObsForRailEnv(ObservationBuilder):
pow4 *= 4
child_size = (len(tree) - num_features_per_node) // 4
tree_data = tree[0:num_features_per_node - 1].tolist()
distance_data = [tree[num_features_per_node - 1]]
distance_data = [tree[num_features_per_node - 3]]
agent_data = tree[-2:]
for children in range(4):
child_tree = tree[(num_features_per_node + children * child_size):
(num_features_per_node + (children + 1) * child_size)]
tmp_tree_data, tmp_distance_data = self.split_tree(child_tree,
tmp_tree_data, tmp_distance_data, tmp_agent_data = self.split_tree(child_tree,
num_features_per_node,
current_depth=current_depth + 1)
if len(tmp_tree_data) > 0:
tree_data.extend(tmp_tree_data)
distance_data.extend(tmp_distance_data)
return tree_data, distance_data
agent_data.extrend(tmp_agent_data)
return tree_data, distance_data, agent_data
class GlobalObsForRailEnv(ObservationBuilder):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment