From 622cab240007294b181a0a941c47cd99dd6ac210 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Wed, 10 Jul 2019 16:19:44 +0200
Subject: [PATCH] #42 run baselines in ci

---
 torch_training/bla.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/torch_training/bla.py b/torch_training/bla.py
index a103f9f..cd1aa13 100644
--- a/torch_training/bla.py
+++ b/torch_training/bla.py
@@ -55,6 +55,33 @@ def main(argv):
                   number_of_agents=n_agents)
     env.reset(True, True)
     file_load = False
+    observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
+    env_renderer = RenderTool(env, gl="PILSVG", )
+    handle = env.get_agent_handles()
+    features_per_node = 9
+    state_size = features_per_node * 85 * 2
+    action_size = 5
+
+    print("main3")
+
+    # Number of episodes to train on: fall back to 30000 only if n_trials was not already set earlier in main()
+    if 'n_trials' not in locals():
+        n_trials = 30000
+    max_steps = int(3 * (env.height + env.width))
+    eps = 1.
+    eps_end = 0.005
+    eps_decay = 0.9995
+    action_dict = dict()
+    final_action_dict = dict()
+    scores_window = deque(maxlen=100)
+    done_window = deque(maxlen=100)
+    time_obs = deque(maxlen=2)
+    scores = []
+    dones_list = []
+    action_prob = [0] * action_size
+    agent_obs = [None] * env.get_num_agents()
+    agent_next_obs = [None] * env.get_num_agents()
+    agent = Agent(state_size, action_size, "FC", 0)
 
 print("multi_agent_trainging.py (2)")
 
-- 
GitLab