From eb97e701210d3be483ee751dfb5ef623a78b9c40 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Thu, 11 Jul 2019 08:41:50 +0200
Subject: [PATCH] #42 run baselines in ci

---
 torch_training/bla.py | 38 +++++++++++++++++++-------------------
 1 file changed, 19 insertions(+), 19 deletions(-)

diff --git a/torch_training/bla.py b/torch_training/bla.py
index ec81697..f76c4ab 100644
--- a/torch_training/bla.py
+++ b/torch_training/bla.py
@@ -117,25 +117,25 @@ def main(argv):
             obs = env.reset(True, True)
         if demo:
             env_renderer.set_new_rail()
-        # obs_original = obs.copy()
-        # final_obs = obs.copy()
-        # final_obs_next = obs.copy()
-        # for a in range(env.get_num_agents()):
-        #     data, distance, agent_data = split_tree(tree=np.array(obs[a]),
-        #                                             current_depth=0)
-        #     data = norm_obs_clip(data)
-        #     distance = norm_obs_clip(distance)
-        #     agent_data = np.clip(agent_data, -1, 1)
-        #     obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
-        #     agent_data = env.agents[a]
-        #     speed = 1  # np.random.randint(1,5)
-        #     agent_data.speed_data['speed'] = 1. / speed
-        #
-        # for i in range(2):
-        #     time_obs.append(obs)
-        # # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
-        # for a in range(env.get_num_agents()):
-        #     agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
+        obs_original = obs.copy()
+        final_obs = obs.copy()
+        final_obs_next = obs.copy()
+        for a in range(env.get_num_agents()):
+            data, distance, agent_data = split_tree(tree=np.array(obs[a]),
+                                                    current_depth=0)
+            data = norm_obs_clip(data)
+            distance = norm_obs_clip(distance)
+            agent_data = np.clip(agent_data, -1, 1)
+            obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
+            agent_data = env.agents[a]
+            speed = 1  # np.random.randint(1,5)
+            agent_data.speed_data['speed'] = 1. / speed
+
+        for i in range(2):
+            time_obs.append(obs)
+        # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+        for a in range(env.get_num_agents()):
+            agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
         #
         # score = 0
         # env_done = 0
-- 
GitLab