From b7d20a7e9f8e99a14cd600e976b8032d0f1d3d2c Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Wed, 10 Jul 2019 16:10:57 +0200 Subject: [PATCH] #42 run baselines in ci --- torch_training/bla.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/torch_training/bla.py b/torch_training/bla.py index a61f707..a103f9f 100644 --- a/torch_training/bla.py +++ b/torch_training/bla.py @@ -29,6 +29,32 @@ def main(argv): if opt in ('-n', '--n_trials'): n_trials = int(arg) print("main1") + random.seed(1) + np.random.seed(1) + + """ + env = RailEnv(width=10, + height=20, obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())) + env.load("./railway/complex_scene.pkl") + file_load = True + """ + + x_dim = np.random.randint(8, 20) + y_dim = np.random.randint(8, 20) + n_agents = np.random.randint(3, 8) + n_goals = n_agents + np.random.randint(0, 3) + min_dist = int(0.75 * min(x_dim, y_dim)) + print("main2") + + env = RailEnv(width=x_dim, + height=y_dim, + rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, + max_dist=99999, + seed=0), + obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), + number_of_agents=n_agents) + env.reset(True, True) + file_load = False print("multi_agent_trainging.py (2)") -- GitLab