From 653126fe437e72bb8425bf533494ccb104197a1b Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Thu, 24 Oct 2019 17:12:19 -0400
Subject: [PATCH] catchin error with None observation

---
 torch_training/render_agent_behavior.py | 6 ++++--
 torch_training/training_navigation.py   | 8 +++++---
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index 62cfa49..dd81bea 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -64,7 +64,7 @@ env = RailEnv(width=x_dim,
               number_of_agents=n_agents,
               stochastic_data=stochastic_data,  # Malfunction data generator
               obs_builder_object=TreeObservation)
-env.reset(True, True)
+env.reset()
 
 env_renderer = RenderTool(env, gl="PILSVG", )
 num_features_per_node = env.obs_builder.observation_dim
@@ -126,10 +126,12 @@ for trials in range(1, n_trials + 1):
             action_dict.update({a: action})
         # Environment step
         obs, all_rewards, done, _ = env.step(action_dict)
+
         env_renderer.render_env(show=True, show_predictions=True, show_observations=False)
         # Build agent specific observations and normalize
         for a in range(env.get_num_agents()):
-            agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
+            if obs[a]:
+                agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
 
 
         if done['__all__']:
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 3358494..daac62d 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -117,8 +117,9 @@ def main(argv):
         env_renderer.reset()
         # Build agent specific observations
         for a in range(env.get_num_agents()):
-            agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
-            agent_obs_buffer[a] = agent_obs[a].copy()
+            if obs[a]:
+                agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
+                agent_obs_buffer[a] = agent_obs[a].copy()
 
         # Reset score and done
         score = 0
@@ -150,7 +151,8 @@ def main(argv):
 
                     agent_obs_buffer[a] = agent_obs[a].copy()
                     agent_action_buffer[a] = action_dict[a]
-                agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
+                if next_obs[a]:
+                    agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
 
                 score += all_rewards[a] / env.get_num_agents()
 
-- 
GitLab