From 1a73719cff573157dfda76b61b1f5e46cc50f010 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Wed, 10 Jul 2019 13:34:52 +0200 Subject: [PATCH] #42 run baselines in ci --- torch_training/bla.py | 9 --------- torch_training/multi_agent_training.py | 11 ++++++++--- tox.ini | 4 ++-- 3 files changed, 10 insertions(+), 14 deletions(-) delete mode 100644 torch_training/bla.py diff --git a/torch_training/bla.py b/torch_training/bla.py deleted file mode 100644 index b5d9064..0000000 --- a/torch_training/bla.py +++ /dev/null @@ -1,9 +0,0 @@ -import sys - -print("bla") -def main(argv): - print("main bla {}".format(argv)) - - -if __name__ == '__main__': - main(sys.argv[1:]) diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 45f60ab..5764f86 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -27,6 +27,7 @@ def main(argv): for opt, arg in opts: if opt in ('-n', '--n_trials'): n_trials = int(arg) + print("main1") random.seed(1) np.random.seed(1) @@ -42,6 +43,8 @@ def main(argv): n_agents = np.random.randint(3, 8) n_goals = n_agents + np.random.randint(0, 3) min_dist = int(0.75 * min(x_dim, y_dim)) + print("main2") + env = RailEnv(width=x_dim, height=y_dim, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, @@ -51,15 +54,16 @@ def main(argv): number_of_agents=n_agents) env.reset(True, True) file_load = False - """ - - """ + observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()) env_renderer = RenderTool(env, gl="PILSVG", ) handle = env.get_agent_handles() features_per_node = 9 state_size = features_per_node * 85 * 2 action_size = 5 + + print("main3") + # We set the number of episodes we would like to train on if 'n_trials' not in locals(): n_trials = 30000 @@ -216,4 +220,5 @@ def main(argv): if __name__ == '__main__': + print("main") main(sys.argv[1:]) diff --git a/tox.ini b/tox.ini index d31b261..da52855 100644 --- a/tox.ini +++ b/tox.ini @@ -22,8 +22,8 @@ passenv = deps = -r{toxinidir}/requirements_torch_training.txt commands = -; python torch_training/multi_agent_training.py --n_trials=10 - python torch_training/bla.py + python torch_training/multi_agent_training.py --n_trials=10 + [flake8] max-line-length = 120 ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W505 -- GitLab