diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 0e5ad18128c115cd86b844f1eb7f0489947e8c37..dc01f7fadf21297087baa3ac0ac35513495a812a 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -9,7 +9,7 @@ from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
-
+from flatland.envs.generators import complex_rail_generator
 from utils.observation_utils import norm_obs_clip, split_tree
 
 random.seed(1)
@@ -40,26 +40,25 @@ env = RailEnv(width=15,
               height=15,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
               number_of_agents=1)
-"""
+
 env = RailEnv(width=10,
               height=20, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
 env.load("./railway/complex_scene.pkl")
 file_load = True
 """
 
-env = RailEnv(width=20,
-              height=20,
-              rail_generator=complex_rail_generator(nr_start_goal=20, nr_extra=5, min_dist=10, max_dist=99999, seed=0),
+env = RailEnv(width=10,
+              height=10,
+              rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=10, max_dist=99999, seed=0),
               obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-              number_of_agents=15)
+              number_of_agents=3)
 file_load = False
 env.reset(True, True)
 
-"""
 env_renderer = RenderTool(env, gl="PILSVG",)
 handle = env.get_agent_handles()
-
-state_size = 168 * 2
+features_per_node = 9
+state_size = features_per_node*21 * 2
 action_size = 5
 n_trials = 15000
 max_steps = int(3 * (env.height + env.width))
@@ -77,9 +77,9 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
+# agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))  # re-enable together with demo=True to evaluate a pretrained checkpoint
 
-demo = True
+demo = False
 record_images = False
 
 
@@ -97,8 +97,7 @@ for trials in range(1, n_trials + 1):
     final_obs = obs.copy()
     final_obs_next = obs.copy()
     for a in range(env.get_num_agents()):
-        print(a)
-        data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=8,
+        data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=features_per_node,
                                                 current_depth=0)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
@@ -136,7 +135,7 @@ for trials in range(1, n_trials + 1):
 
         next_obs, all_rewards, done, _ = env.step(action_dict)
         for a in range(env.get_num_agents()):
-            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), num_features_per_node=8,
+            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), num_features_per_node=features_per_node,
                                                     current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
diff --git a/utils/observation_utils.py b/utils/observation_utils.py
index 63adfff634b58c555be064dcb13893008875290d..0c97b186a9331f185cf1a1d3f99685581cb551f7 100644
--- a/utils/observation_utils.py
+++ b/utils/observation_utils.py
@@ -48,7 +48,7 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
     return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
 
 
-def split_tree(tree, num_features_per_node=8, current_depth=0):
+def split_tree(tree, num_features_per_node=9, current_depth=0):
     """
     Splits the tree observation into different sub groups that need the same normalization.
     This is necessary because the tree observation includes two different distance:
@@ -80,10 +80,9 @@ def split_tree(tree, num_features_per_node=8, current_depth=0):
     Here we split the node features into the different classes of distances and binary values.
     Pay close attention to this part if you modify any of the features in the tree observation.
     """
-    tree_data = tree[:4].tolist()
-    distance_data = [tree[4]]
-    agent_data = tree[5:num_features_per_node].tolist()
-
+    tree_data = tree[:6].tolist()
+    distance_data = [tree[6]]
+    agent_data = tree[7:num_features_per_node].tolist()
     # Split each child of the current node and continue to next depth level
     for children in range(4):
         child_tree = tree[(num_features_per_node + children * child_size):