From c93e9d56267411149105622f505407e7640659ff Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Fri, 7 Jun 2019 10:56:46 +0200
Subject: [PATCH] error introduced for christian to test

---
 torch_training/training_navigation.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 2a4af22..198f0ee 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -65,7 +65,7 @@ action_prob = [0] * 4
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint10400.pth'))
+agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
 
 demo = True
 
@@ -119,18 +119,18 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
 for trials in range(1, n_trials + 1):
 
     # Reset environment
-    obs = env.reset(False,False)
-
+    obs = env.reset(False, False)
+    print(len(obs[0]))
     final_obs = obs.copy()
     final_obs_next = obs.copy()
 
     for a in range(env.get_num_agents()):
-        data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0)
-
+        data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=7, current_depth=0)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
-        obs[a] = np.concatenate((data, distance))
 
+        obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
+        print(len(data) + len(distance) + len(agent_data), len(obs[a]))
     for i in range(2):
         time_obs.append(obs)
     # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
@@ -156,11 +156,11 @@ for trials in range(1, n_trials + 1):
         # Environment step
         next_obs, all_rewards, done, _ = env.step(action_dict)
         for a in range(env.get_num_agents()):
-            data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5,
+            data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=7,
                                                         current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
-            next_obs[a] = np.concatenate((data, distance))
+            next_obs[a] = np.concatenate((np.concatenate((data, distance)),agent_data))
 
         time_obs.append(next_obs)
 
-- 
GitLab