diff --git a/utils/observation_utils.py b/utils/observation_utils.py index 75e6b0ad2a0e6f82fbf62a15fbd25f7d6c95537e..63adfff634b58c555be064dcb13893008875290d 100644 --- a/utils/observation_utils.py +++ b/utils/observation_utils.py @@ -50,12 +50,19 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): def split_tree(tree, num_features_per_node=8, current_depth=0): """ - - :param tree: - :param num_features_per_node: - :param prompt: - :param current_depth: - :return: + Splits the tree observation into different sub groups that need the same normalization. + This is necessary because the tree observation includes two different distance: + 1. Distance from the agent --> This is measured in cells from current agent location + 2. Distance to targer --> This is measured as distance from cell to agent target + 3. Binary data --> Contains information about presence of object --> No normalization necessary + Number 1. will depend on the depth and size of the tree search + Number 2. will depend on the size of the map and thus the max distance on the map + Number 3. Is independent of tree depth and map size and thus must be handled differently + Therefore we split the tree into these two classes for better normalization. + :param tree: Tree that needs to be split + :param num_features_per_node: Features per node ATTENTION! this parameter is vital to correct splitting of the tree. + :param current_depth: Keeping track of the current depth in the tree + :return: Returns the three different groups of distance and binary values. """ if len(tree) < num_features_per_node: @@ -69,9 +76,15 @@ def split_tree(tree, num_features_per_node=8, current_depth=0): depth += 1 pow4 *= 4 child_size = (len(tree) - num_features_per_node) // 4 + """ + Here we split the node features into the different classes of distances and binary values. + Pay close attention to this part if you modify any of the features in the tree observation. + """ tree_data = tree[:4].tolist() distance_data = [tree[4]] agent_data = tree[5:num_features_per_node].tolist() + + # Split each child of the current node and continue to next depth level for children in range(4): child_tree = tree[(num_features_per_node + children * child_size): (num_features_per_node + (children + 1) * child_size)]