diff --git a/utils/observation_utils.py b/utils/observation_utils.py
index 75e6b0ad2a0e6f82fbf62a15fbd25f7d6c95537e..63adfff634b58c555be064dcb13893008875290d 100644
--- a/utils/observation_utils.py
+++ b/utils/observation_utils.py
@@ -50,12 +50,19 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
 
 def split_tree(tree, num_features_per_node=8, current_depth=0):
     """
-
-    :param tree:
-    :param num_features_per_node:
-    :param prompt:
-    :param current_depth:
-    :return:
+    Splits the tree observation into different sub groups that need the same normalization.
+    This is necessary because the tree observation includes two different distance:
+    1. Distance from the agent --> This is measured in cells from current agent location
+    2. Distance to targer --> This is measured as distance from cell to agent target
+    3. Binary data --> Contains information about presence of object --> No normalization necessary
+    Number 1. will depend on the depth and size of the tree search
+    Number 2. will depend on the size of the map and thus the max distance on the map
+    Number 3. Is independent of tree depth and map size and thus must be handled differently
+    Therefore we split the tree into these two classes for better normalization.
+    :param tree: Tree that needs to be split
+    :param num_features_per_node: Features per node ATTENTION! this parameter is vital to correct splitting of the tree.
+    :param current_depth: Keeping track of the current depth in the tree
+    :return: Returns the three different groups of distance and binary values.
     """
 
     if len(tree) < num_features_per_node:
@@ -69,9 +76,15 @@ def split_tree(tree, num_features_per_node=8, current_depth=0):
         depth += 1
         pow4 *= 4
     child_size = (len(tree) - num_features_per_node) // 4
+    """
+    Here we split the node features into the different classes of distances and binary values.
+    Pay close attention to this part if you modify any of the features in the tree observation.
+    """
     tree_data = tree[:4].tolist()
     distance_data = [tree[4]]
     agent_data = tree[5:num_features_per_node].tolist()
+
+    # Split each child of the current node and continue to next depth level
     for children in range(4):
         child_tree = tree[(num_features_per_node + children * child_size):
                           (num_features_per_node + (children + 1) * child_size)]