diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index 1d760c6ee60c62286aa5589ff358449daa5f8ed5..bc377e393e51ac17dba9d198ef44565e15c1c062 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -43,6 +43,7 @@ env.reset(True, True)
 """
 observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
 env_renderer = RenderTool(env, gl="PILSVG", )
+num_features_per_node = env.obs_builder.observation_dim
 handle = env.get_agent_handles()
 features_per_node = 9
 state_size = features_per_node * 85 * 2
@@ -82,7 +83,7 @@ for trials in range(1, n_trials + 1):
     final_obs = obs.copy()
     final_obs_next = obs.copy()
     for a in range(env.get_num_agents()):
-        data, distance, agent_data = split_tree(tree=np.array(obs[a]),
+        data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=num_features_per_node,
                                                 current_depth=0)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
@@ -118,6 +119,7 @@ for trials in range(1, n_trials + 1):
         obs_original = next_obs.copy()
         for a in range(env.get_num_agents()):
             data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
+                                                    num_features_per_node=num_features_per_node,
                                                     current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)