From b7d20a7e9f8e99a14cd600e976b8032d0f1d3d2c Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Wed, 10 Jul 2019 16:10:57 +0200
Subject: [PATCH] #42 run baselines in ci

---
 torch_training/bla.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/torch_training/bla.py b/torch_training/bla.py
index a61f707..a103f9f 100644
--- a/torch_training/bla.py
+++ b/torch_training/bla.py
@@ -29,6 +29,32 @@ def main(argv):
         if opt in ('-n', '--n_trials'):
             n_trials = int(arg)
     print("main1")
+    random.seed(1)
+    np.random.seed(1)
+
+    """
+    env = RailEnv(width=10,
+                  height=20, obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
+    env.load("./railway/complex_scene.pkl")
+    file_load = True
+    """
+
+    x_dim = np.random.randint(8, 20)
+    y_dim = np.random.randint(8, 20)
+    n_agents = np.random.randint(3, 8)
+    n_goals = n_agents + np.random.randint(0, 3)
+    min_dist = int(0.75 * min(x_dim, y_dim))
+    print("main2")
+
+    env = RailEnv(width=x_dim,
+                  height=y_dim,
+                  rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
+                                                        max_dist=99999,
+                                                        seed=0),
+                  obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
+                  number_of_agents=n_agents)
+    env.reset(True, True)
+    file_load = False
 
 print("multi_agent_trainging.py (2)")
 
-- 
GitLab