diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index 62cfa494072d2ac3e8d164beb54f06ef202544d0..dd81bead6bb26d40fb73987055ffb20e781f74eb 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -64,7 +64,7 @@ env = RailEnv(width=x_dim,
               number_of_agents=n_agents,
               stochastic_data=stochastic_data,  # Malfunction data generator
               obs_builder_object=TreeObservation)
-env.reset(True, True)
+env.reset()
 
 env_renderer = RenderTool(env, gl="PILSVG", )
 num_features_per_node = env.obs_builder.observation_dim
@@ -126,10 +126,12 @@ for trials in range(1, n_trials + 1):
             action_dict.update({a: action})
 
         # Environment step
         obs, all_rewards, done, _ = env.step(action_dict)
+        env_renderer.render_env(show=True, show_predictions=True, show_observations=False)
 
         # Build agent specific observations and normalize
         for a in range(env.get_num_agents()):
-            agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
+            if obs[a]:
+                agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
 
         if done['__all__']:
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 335849499db4b5563ea1fea6af572f0fead1ee01..daac62dbd26d190a2341d6dedc5ab0cfef1a2dff 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -117,8 +117,9 @@ def main(argv):
         env_renderer.reset()
         # Build agent specific observations
         for a in range(env.get_num_agents()):
-            agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
-            agent_obs_buffer[a] = agent_obs[a].copy()
+            if obs[a]:
+                agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
+                agent_obs_buffer[a] = agent_obs[a].copy()
 
         # Reset score and done
         score = 0
@@ -150,7 +151,8 @@ def main(argv):
                     agent_obs_buffer[a] = agent_obs[a].copy()
                     agent_action_buffer[a] = action_dict[a]
 
-                agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
+                if next_obs[a]:
+                    agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
 
                 score += all_rewards[a] / env.get_num_agents()
 