diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py index 1d760c6ee60c62286aa5589ff358449daa5f8ed5..bc377e393e51ac17dba9d198ef44565e15c1c062 100644 --- a/torch_training/render_agent_behavior.py +++ b/torch_training/render_agent_behavior.py @@ -43,6 +43,7 @@ env.reset(True, True) """ observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()) env_renderer = RenderTool(env, gl="PILSVG", ) +num_features_per_node = env.obs_builder.observation_dim handle = env.get_agent_handles() features_per_node = 9 state_size = features_per_node * 85 * 2 @@ -82,7 +83,7 @@ for trials in range(1, n_trials + 1): final_obs = obs.copy() final_obs_next = obs.copy() for a in range(env.get_num_agents()): - data, distance, agent_data = split_tree(tree=np.array(obs[a]), + data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=num_features_per_node, current_depth=0) data = norm_obs_clip(data) distance = norm_obs_clip(distance) @@ -118,6 +119,7 @@ for trials in range(1, n_trials + 1): obs_original = next_obs.copy() for a in range(env.get_num_agents()): data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), + num_features_per_node=num_features_per_node, current_depth=0) data = norm_obs_clip(data) distance = norm_obs_clip(distance)