diff --git a/torch_training/bla.py b/torch_training/bla.py deleted file mode 100644 index b5d9064f48f0c2d6c0beb72421b5c0f0788883c1..0000000000000000000000000000000000000000 --- a/torch_training/bla.py +++ /dev/null @@ -1,9 +0,0 @@ -import sys - -print("bla") -def main(argv): - print("main bla {}".format(argv)) - - -if __name__ == '__main__': - main(sys.argv[1:]) diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 45f60ab0a7c97acc47c690ef75c279c655db191d..5764f865f712514e3d4cbef1648822dcc528554d 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -27,6 +27,7 @@ def main(argv): for opt, arg in opts: if opt in ('-n', '--n_trials'): n_trials = int(arg) + print("main1") random.seed(1) np.random.seed(1) @@ -42,6 +43,8 @@ def main(argv): n_agents = np.random.randint(3, 8) n_goals = n_agents + np.random.randint(0, 3) min_dist = int(0.75 * min(x_dim, y_dim)) + print("main2") + env = RailEnv(width=x_dim, height=y_dim, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, @@ -51,15 +54,16 @@ def main(argv): number_of_agents=n_agents) env.reset(True, True) file_load = False - """ - - """ + observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()) env_renderer = RenderTool(env, gl="PILSVG", ) handle = env.get_agent_handles() features_per_node = 9 state_size = features_per_node * 85 * 2 action_size = 5 + + print("main3") + # We set the number of episodes we would like to train on if 'n_trials' not in locals(): n_trials = 30000 @@ -216,4 +220,5 @@ def main(argv): if __name__ == '__main__': + print("main") main(sys.argv[1:]) diff --git a/tox.ini b/tox.ini index d31b261df3d41ff770be6d65bfccdf36b308d7cf..da528553c74be43bc54ce7858bb057e117cee11a 100644 --- a/tox.ini +++ b/tox.ini @@ -22,8 +22,8 @@ passenv = deps = -r{toxinidir}/requirements_torch_training.txt commands = -; python torch_training/multi_agent_training.py --n_trials=10 - python torch_training/bla.py + python torch_training/multi_agent_training.py --n_trials=10 + [flake8] max-line-length = 120 ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W505