From 3d1219e32cef517e502790e556f160b7092c103c Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Wed, 3 Jul 2019 13:03:24 -0400
Subject: [PATCH] Commented functions for better understanding

---
 utils/observation_utils.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/utils/observation_utils.py b/utils/observation_utils.py
index 75e6b0a..63adfff 100644
--- a/utils/observation_utils.py
+++ b/utils/observation_utils.py
@@ -50,12 +50,19 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
 
 def split_tree(tree, num_features_per_node=8, current_depth=0):
     """
-
-    :param tree:
-    :param num_features_per_node:
-    :param prompt:
-    :param current_depth:
-    :return:
+    Splits the tree observation into different sub groups that need the same normalization.
+    This is necessary because the tree observation includes two different distance:
+    1. Distance from the agent --> This is measured in cells from current agent location
+    2. Distance to targer --> This is measured as distance from cell to agent target
+    3. Binary data --> Contains information about presence of object --> No normalization necessary
+    Number 1. will depend on the depth and size of the tree search
+    Number 2. will depend on the size of the map and thus the max distance on the map
+    Number 3. Is independent of tree depth and map size and thus must be handled differently
+    Therefore we split the tree into these two classes for better normalization.
+    :param tree: Tree that needs to be split
+    :param num_features_per_node: Features per node ATTENTION! this parameter is vital to correct splitting of the tree.
+    :param current_depth: Keeping track of the current depth in the tree
+    :return: Returns the three different groups of distance and binary values.
     """
 
     if len(tree) < num_features_per_node:
@@ -69,9 +76,15 @@ def split_tree(tree, num_features_per_node=8, current_depth=0):
         depth += 1
         pow4 *= 4
     child_size = (len(tree) - num_features_per_node) // 4
+    """
+    Here we split the node features into the different classes of distances and binary values.
+    Pay close attention to this part if you modify any of the features in the tree observation.
+    """
     tree_data = tree[:4].tolist()
     distance_data = [tree[4]]
     agent_data = tree[5:num_features_per_node].tolist()
+
+    # Split each child of the current node and continue to next depth level
     for children in range(4):
         child_tree = tree[(num_features_per_node + children * child_size):
                           (num_features_per_node + (children + 1) * child_size)]
-- 
GitLab