From 713db7235f4d8438d1bfa7dc04dea7cbc4108112 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Wed, 10 Jul 2019 08:52:51 +0200
Subject: [PATCH] #42 run baselines in ci

---
 torch_training/multi_agent_training.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index 38f7c1c..45f60ab 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -1,10 +1,10 @@
 import getopt
+import random
 import sys
 from collections import deque
 
 import matplotlib.pyplot as plt
 import numpy as np
-import random
 import torch
 from dueling_double_dqn import Agent
 from importlib_resources import path
@@ -17,6 +17,7 @@ from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
 from utils.observation_utils import norm_obs_clip, split_tree
 
+
 def main(argv):
     try:
         opts, args = getopt.getopt(argv, "n:", ["n_trials="])
@@ -24,7 +25,7 @@ def main(argv):
         print('training_navigation.py -n <n_trials>')
         sys.exit(2)
     for opt, arg in opts:
-        if opt in ('-n','--n_trials'):
+        if opt in ('-n', '--n_trials'):
             n_trials = int(arg)
     random.seed(1)
     np.random.seed(1)
@@ -83,6 +84,8 @@ def main(argv):
     demo = False
     record_images = False
     frame_step = 0
+
+    print("Going to run training for {} trials...".format(n_trials))
     for trials in range(1, n_trials + 1):
 
         if trials % 50 == 0 and not demo:
@@ -96,7 +99,8 @@ def main(argv):
                           rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
                                                                 max_dist=99999,
                                                                 seed=0),
-                          obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
+                          obs_builder_object=TreeObsForRailEnv(max_depth=3,
+                                                               predictor=ShortestPathPredictorForRailEnv()),
                           number_of_agents=n_agents)
             env.reset(True, True)
             max_steps = int(3 * (env.height + env.width))
@@ -210,5 +214,6 @@ def main(argv):
     plt.plot(scores)
     plt.show()
 
+
 if __name__ == '__main__':
     main(sys.argv[1:])
-- 
GitLab