Skip to content
Snippets Groups Projects
Commit 735c8624 authored by Christian Eichenberger's avatar Christian Eichenberger :badminton:
Browse files

Merge branch '56-bugfix-tree-observation-builder' into 'master'

#56 bugfix dimensions in explore branch

Closes #56

See merge request flatland/flatland!53
parents 00201b71 cf930ee2
No related branches found
No related tags found
No related merge requests found
...@@ -437,7 +437,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -437,7 +437,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth - depth): for i in range(self.max_depth - depth):
num_cells_to_fill_in += pow4 num_cells_to_fill_in += pow4
pow4 *= 4 pow4 *= 4
observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in observation = observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in
return observation, visited return observation, visited
...@@ -496,8 +496,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -496,8 +496,8 @@ class TreeObsForRailEnv(ObservationBuilder):
child_tree = tree[(num_features_per_node + children * child_size): child_tree = tree[(num_features_per_node + children * child_size):
(num_features_per_node + (children + 1) * child_size)] (num_features_per_node + (children + 1) * child_size)]
tmp_tree_data, tmp_distance_data, tmp_agent_data = self.split_tree(child_tree, tmp_tree_data, tmp_distance_data, tmp_agent_data = self.split_tree(child_tree,
num_features_per_node, num_features_per_node,
current_depth=current_depth + 1) current_depth=current_depth + 1)
if len(tmp_tree_data) > 0: if len(tmp_tree_data) > 0:
tree_data.extend(tmp_tree_data) tree_data.extend(tmp_tree_data)
distance_data.extend(tmp_distance_data) distance_data.extend(tmp_distance_data)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment