From 4fc55df4e1fd124eb03e3b471898e186e2760955 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Mon, 15 Jul 2019 15:02:52 -0400 Subject: [PATCH] updated agent behavior rendering code --- torch_training/render_agent_behavior.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py index 1d760c6..bc377e3 100644 --- a/torch_training/render_agent_behavior.py +++ b/torch_training/render_agent_behavior.py @@ -43,6 +43,7 @@ env.reset(True, True) """ observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()) env_renderer = RenderTool(env, gl="PILSVG", ) +num_features_per_node = env.obs_builder.observation_dim handle = env.get_agent_handles() features_per_node = 9 state_size = features_per_node * 85 * 2 @@ -82,7 +83,7 @@ for trials in range(1, n_trials + 1): final_obs = obs.copy() final_obs_next = obs.copy() for a in range(env.get_num_agents()): - data, distance, agent_data = split_tree(tree=np.array(obs[a]), + data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=num_features_per_node, current_depth=0) data = norm_obs_clip(data) distance = norm_obs_clip(distance) @@ -118,6 +119,7 @@ for trials in range(1, n_trials + 1): obs_original = next_obs.copy() for a in range(env.get_num_agents()): data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), + num_features_per_node=num_features_per_node, current_depth=0) data = norm_obs_clip(data) distance = norm_obs_clip(distance) -- GitLab