From 713db7235f4d8438d1bfa7dc04dea7cbc4108112 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Wed, 10 Jul 2019 08:52:51 +0200 Subject: [PATCH] #42 run baselines in ci --- torch_training/multi_agent_training.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 38f7c1c..45f60ab 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -1,10 +1,10 @@ import getopt +import random import sys from collections import deque import matplotlib.pyplot as plt import numpy as np -import random import torch from dueling_double_dqn import Agent from importlib_resources import path @@ -17,6 +17,7 @@ from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool from utils.observation_utils import norm_obs_clip, split_tree + def main(argv): try: opts, args = getopt.getopt(argv, "n:", ["n_trials="]) @@ -24,7 +25,7 @@ def main(argv): print('training_navigation.py -n <n_trials>') sys.exit(2) for opt, arg in opts: - if opt in ('-n','--n_trials'): + if opt in ('-n', '--n_trials'): n_trials = int(arg) random.seed(1) np.random.seed(1) @@ -83,6 +84,8 @@ def main(argv): demo = False record_images = False frame_step = 0 + + print("Going to run training for {} trials...".format(n_trials)) for trials in range(1, n_trials + 1): if trials % 50 == 0 and not demo: @@ -96,7 +99,8 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), - obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), + obs_builder_object=TreeObsForRailEnv(max_depth=3, + predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) env.reset(True, True) max_steps = int(3 * (env.height + env.width)) @@ -210,5 +214,6 @@ def main(argv): plt.plot(scores) plt.show() + if __name__ == '__main__': main(sys.argv[1:]) -- GitLab