Commit c4747283 authored by Erik Nygren's avatar Erik Nygren
Browse files

moved split_tree into baseline repo

parent 2de81f41
Pipeline #1316 passed with stages
in 10 minutes and 41 seconds
......@@ -557,42 +557,6 @@ class TreeObsForRailEnv(ObservationBuilder):
prompt=prompt_[children],
current_depth=current_depth + 1)
def split_tree(self, tree, num_features_per_node=8, current_depth=0):
"""
:param tree:
:param num_features_per_node:
:param prompt:
:param current_depth:
:return:
"""
if len(tree) < num_features_per_node:
return [], [], []
depth = 0
tmp = len(tree) / num_features_per_node - 1
pow4 = 4
while tmp > 0:
tmp -= pow4
depth += 1
pow4 *= 4
child_size = (len(tree) - num_features_per_node) // 4
tree_data = tree[:4].tolist()
distance_data = [tree[4]]
agent_data = tree[5:num_features_per_node].tolist()
for children in range(4):
child_tree = tree[(num_features_per_node + children * child_size):
(num_features_per_node + (children + 1) * child_size)]
tmp_tree_data, tmp_distance_data, tmp_agent_data = self.split_tree(child_tree,
num_features_per_node,
current_depth=current_depth + 1)
if len(tmp_tree_data) > 0:
tree_data.extend(tmp_tree_data)
distance_data.extend(tmp_distance_data)
agent_data.extend(tmp_agent_data)
return tree_data, distance_data, agent_data
def _set_env(self, env):
self.env = env
if self.predictor:
......
......@@ -111,7 +111,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
"version": "3.6.5"
},
"latex_envs": {
"LaTeX_envs_menu_present": true,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment