From 653126fe437e72bb8425bf533494ccb104197a1b Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Thu, 24 Oct 2019 17:12:19 -0400
Subject: [PATCH] catching error with None observation

---
 torch_training/render_agent_behavior.py | 6 ++++--
 torch_training/training_navigation.py   | 8 +++++---
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index 62cfa49..dd81bea 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -64,7 +64,7 @@ env = RailEnv(width=x_dim,
               number_of_agents=n_agents,
               stochastic_data=stochastic_data,  # Malfunction data generator
               obs_builder_object=TreeObservation)
-env.reset(True, True)
+env.reset()
 
 env_renderer = RenderTool(env, gl="PILSVG", )
 num_features_per_node = env.obs_builder.observation_dim
@@ -126,10 +126,12 @@ for trials in range(1, n_trials + 1):
             action_dict.update({a: action})
 
         # Environment step
         obs, all_rewards, done, _ = env.step(action_dict)
 
+        env_renderer.render_env(show=True, show_predictions=True, show_observations=False)
         # Build agent specific observations and normalize
         for a in range(env.get_num_agents()):
-            agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
+            if obs[a]:
+                agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
 
         if done['__all__']:
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 3358494..daac62d 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -117,8 +117,9 @@ def main(argv):
         env_renderer.reset()
         # Build agent specific observations
         for a in range(env.get_num_agents()):
-            agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
-            agent_obs_buffer[a] = agent_obs[a].copy()
+            if obs[a]:
+                agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
+                agent_obs_buffer[a] = agent_obs[a].copy()
 
         # Reset score and done
         score = 0
@@ -150,7 +151,8 @@ def main(argv):
 
                 agent_obs_buffer[a] = agent_obs[a].copy()
                 agent_action_buffer[a] = action_dict[a]
 
-            agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
+            if next_obs[a]:
+                agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
 
             score += all_rewards[a] / env.get_num_agents()
--
GitLab
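
Note on the change: the Flatland tree observation builder can return None
for some agents (for example, agents that have already reached their
target), and passing None into normalize_observation() raises an error.
The patch therefore guards each call, so agent_obs[a] simply keeps its
previous value when no fresh observation is available. A minimal sketch of
the pattern, assuming normalize_observation comes from the repo's
utils.observation_utils module; the helper name build_agent_obs is
illustrative and not part of the patch:

    from utils.observation_utils import normalize_observation

    def build_agent_obs(obs, agent_obs, tree_depth):
        # Normalize only the observations that exist; an agent whose tree
        # observation is None keeps its previous normalized observation,
        # so the network still receives a well-formed input for it.
        for handle, tree_obs in obs.items():
            if tree_obs:
                agent_obs[handle] = normalize_observation(
                    tree_obs, tree_depth, observation_radius=10)
        return agent_obs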