From 3d1219e32cef517e502790e556f160b7092c103c Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Wed, 3 Jul 2019 13:03:24 -0400 Subject: [PATCH] Commented functions for better understanding --- utils/observation_utils.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/utils/observation_utils.py b/utils/observation_utils.py index 75e6b0a..63adfff 100644 --- a/utils/observation_utils.py +++ b/utils/observation_utils.py @@ -50,12 +50,19 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): def split_tree(tree, num_features_per_node=8, current_depth=0): """ - - :param tree: - :param num_features_per_node: - :param prompt: - :param current_depth: - :return: + Splits the tree observation into different sub groups that need the same normalization. + This is necessary because the tree observation includes two different distance: + 1. Distance from the agent --> This is measured in cells from current agent location + 2. Distance to targer --> This is measured as distance from cell to agent target + 3. Binary data --> Contains information about presence of object --> No normalization necessary + Number 1. will depend on the depth and size of the tree search + Number 2. will depend on the size of the map and thus the max distance on the map + Number 3. Is independent of tree depth and map size and thus must be handled differently + Therefore we split the tree into these two classes for better normalization. + :param tree: Tree that needs to be split + :param num_features_per_node: Features per node ATTENTION! this parameter is vital to correct splitting of the tree. + :param current_depth: Keeping track of the current depth in the tree + :return: Returns the three different groups of distance and binary values. """ if len(tree) < num_features_per_node: @@ -69,9 +76,15 @@ def split_tree(tree, num_features_per_node=8, current_depth=0): depth += 1 pow4 *= 4 child_size = (len(tree) - num_features_per_node) // 4 + """ + Here we split the node features into the different classes of distances and binary values. + Pay close attention to this part if you modify any of the features in the tree observation. + """ tree_data = tree[:4].tolist() distance_data = [tree[4]] agent_data = tree[5:num_features_per_node].tolist() + + # Split each child of the current node and continue to next depth level for children in range(4): child_tree = tree[(num_features_per_node + children * child_size): (num_features_per_node + (children + 1) * child_size)] -- GitLab