diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py
index d06506383544df5ef03b4128c37213d0a1a2ac69..989800044250d60b68c68b5d2e702b5625964024 100644
--- a/RLLib_training/RailEnvRLLibWrapper.py
+++ b/RLLib_training/RailEnvRLLibWrapper.py
@@ -1,9 +1,10 @@
 import numpy as np
-from flatland.envs.generators import complex_rail_generator, random_rail_generator
-from flatland.envs.rail_env import RailEnv
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.utils.seed import seed as set_seed
 
+from flatland.envs.generators import complex_rail_generator, random_rail_generator
+from flatland.envs.rail_env import RailEnv
+
 
 class RailEnvRLLibWrapper(MultiAgentEnv):
 
@@ -63,7 +64,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
 
         for i_agent in range(len(self.env.agents)):
             data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]),
-                                                                         num_features_per_node=8, current_depth=0)
+                                                                         current_depth=0)
             o[i_agent] = [data, distance, agent_data]
 
         # needed for the renderer
@@ -72,8 +73,6 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         self.agents_static = self.env.agents_static
         self.dev_obs_dict = self.env.dev_obs_dict
 
-
-
          # If step_memory > 1, we need to concatenate the observations held in memory;
          # this only works for step_memory = 1 or 2 for the moment
         if self.step_memory < 2:
@@ -96,7 +95,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         for i_agent in range(len(self.env.agents)):
             if i_agent not in self.agents_done:
                 data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]),
-                                                                             num_features_per_node=8, current_depth=0)
+                                                                             current_depth=0)
 
                 o[i_agent] = [data, distance, agent_data]
                 r[i_agent] = rewards[i_agent]
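
The two hunks above drop the hardcoded num_features_per_node=8 from the wrapper's
split_tree call sites, so the per-node feature width is evidently resolved inside
the observation builder now. A minimal sketch of the resulting call pattern,
assuming a flatland env wired up as in this repo (only split_tree's two-argument
form is taken from the diff; the surrounding function is illustrative):

    import numpy as np

    def build_agent_obs(env, obs):
        # Per-agent observation assembly after the change: split_tree needs only
        # the raw tree and the starting depth, not the feature width.
        o = {}
        for i_agent in range(len(env.agents)):
            data, distance, agent_data = env.obs_builder.split_tree(
                tree=np.array(obs[i_agent]), current_depth=0)
            o[i_agent] = [data, distance, agent_data]
        return o
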
diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index a7104d35255b8b96e0e001f8973c109914fece14..7b9470d3b980657073af02e87b92edfea98bd879 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -100,7 +100,7 @@ for trials in range(1, n_trials + 1):
     final_obs = obs.copy()
     final_obs_next = obs.copy()
     for a in range(env.get_num_agents()):
-        data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=features_per_node,
+        data, distance, agent_data = split_tree(tree=np.array(obs[a]),
                                                 current_depth=0)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
@@ -141,7 +141,7 @@ for trials in range(1, n_trials + 1):
         # print(all_rewards,action)
         obs_original = next_obs.copy()
         for a in range(env.get_num_agents()):
-            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), num_features_per_node=features_per_node,
+            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
                                                     current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
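
Same mechanical change in both training-loop hunks here. For reference, the pipeline
these call sites feed is: split the tree, clip-normalize each stream, then concatenate
into the network input. A self-contained sketch with a simplified stand-in for
norm_obs_clip (the real helper lives in the repo's observation utilities; the scaling
rule and the concatenation below are illustrative assumptions):

    import numpy as np

    def norm_obs_clip(obs, clip_min=-1, clip_max=1):
        # Simplified stand-in: scale by the largest finite magnitude, then clip.
        # Infinite entries (unreachable branches) saturate at the clip bounds.
        obs = np.asarray(obs, dtype=float)
        finite = np.abs(obs[np.isfinite(obs)])
        scale = finite.max() if finite.size and finite.max() > 0 else 1.0
        return np.clip(obs / scale, clip_min, clip_max)

    def agent_state(data, distance, agent_data):
        # Normalize each stream separately, then stack into one flat state vector.
        return np.concatenate((norm_obs_clip(data),
                               norm_obs_clip(distance),
                               norm_obs_clip(agent_data)))
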
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 8410bf209d3519e214d7a2cf84283e767c3aca9b..88aa57edcd7b611c3be5de83604339c91f38202e 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -1,13 +1,10 @@
 from collections import deque
+
 import matplotlib.pyplot as plt
 import numpy as np
 import random
-
 import torch
 from dueling_double_dqn import Agent
-from importlib_resources import path
-
-import torch_training.Nets
 
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
@@ -91,7 +88,6 @@ for trials in range(1, n_trials + 1):
     # Build agent specific local observation
     for a in range(env.get_num_agents()):
         rail_data, distance_data, agent_data = split_tree(tree=np.array(obs[a]),
-                                                          num_features_per_node=features_per_node,
                                                           current_depth=0)
         rail_data = norm_obs_clip(rail_data)
         distance_data = norm_obs_clip(distance_data)
@@ -125,7 +121,6 @@ for trials in range(1, n_trials + 1):
 
         for a in range(env.get_num_agents()):
             rail_data, distance_data, agent_data = split_tree(tree=np.array(next_obs[a]),
-                                                              num_features_per_node=features_per_node,
                                                               current_depth=0)
             rail_data = norm_obs_clip(rail_data)
             distance_data = norm_obs_clip(distance_data)
@@ -185,7 +180,7 @@ env_renderer.set_new_rail()
 # Split the observation tree into its parts and normalize the observation using the utility functions.
 # Build agent specific local observation
 for a in range(env.get_num_agents()):
-    rail_data, distance_data, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=features_per_node,
+    rail_data, distance_data, agent_data = split_tree(tree=np.array(obs[a]),
                                                       current_depth=0)
     rail_data = norm_obs_clip(rail_data)
     distance_data = norm_obs_clip(distance_data)
@@ -211,7 +206,6 @@ for step in range(max_steps):
 
     for a in range(env.get_num_agents()):
         rail_data, distance_data, agent_data = split_tree(tree=np.array(next_obs[a]),
-                                                          num_features_per_node=features_per_node,
                                                           current_depth=0)
         rail_data = norm_obs_clip(rail_data)
         distance_data = norm_obs_clip(distance_data)
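
Besides the split_tree cleanup, this file sheds two imports that are no longer used
(importlib_resources.path and torch_training.Nets). As for why the dropped argument
existed at all: the tree observation arrives as one flat vector with a fixed number
of features per node, and the splitter must know that width to carve out the distance
and agent-state columns. A toy illustration of the idea (this node layout is invented
for the example and is not Flatland's actual format):

    import numpy as np

    def split_flat_nodes(flat, num_features_per_node):
        # Toy split: view the flat vector as (nodes, features), then treat the
        # first column as distance, the last as agent state, the rest as data.
        nodes = np.asarray(flat, dtype=float).reshape(-1, num_features_per_node)
        return nodes[:, 1:-1], nodes[:, 0], nodes[:, -1]

    data, distance, agent_data = split_flat_nodes(np.arange(16.0),
                                                  num_features_per_node=8)

Hardcoding that width at every call site means any change to the node format has to
be chased through the whole codebase; moving it into the observation builder makes
the builder the single source of truth.
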
diff --git a/utils/misc_utils.py b/utils/misc_utils.py
index 03c9fdde9368bf324f7e10841b2d30b993858fd6..5b29c6b15f61b46062bac8d4fb6c4130fe61c6ec 100644
--- a/utils/misc_utils.py
+++ b/utils/misc_utils.py
@@ -101,7 +101,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
         lp_reset(True, True)
         obs = env.reset(True, True)
         for a in range(env.get_num_agents()):
-            data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=9,
+            data, distance, agent_data = split_tree(tree=np.array(obs[a]),
                                                     current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
@@ -127,7 +127,6 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
 
             for a in range(env.get_num_agents()):
                 data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
-                                                        num_features_per_node=features_per_node,
                                                         current_depth=0)
                 data = norm_obs_clip(data)
                 distance = norm_obs_clip(distance)
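
Note the value deleted in the first hunk of this file: run_test hardcoded
num_features_per_node=9 while the training scripts passed 8 (or a features_per_node
variable), which is exactly the kind of drift this patch eliminates. With the builder
owning the width, a mismatch is impossible by construction; if an explicit guard is
still wanted, a cheap check against the builder suffices (observation_dim is an
assumed attribute name for the builder's per-node width, not taken from this patch):

    # Hypothetical sanity check after env.reset(): the flat tree observation
    # length should be a whole number of nodes at the builder's feature width.
    width = env.obs_builder.observation_dim
    assert len(obs[0]) % width == 0, "tree observation width drifted"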