From 1a73719cff573157dfda76b61b1f5e46cc50f010 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Wed, 10 Jul 2019 13:34:52 +0200
Subject: [PATCH] #42 run baselines in ci

---
 torch_training/bla.py                  |  9 ---------
 torch_training/multi_agent_training.py | 11 ++++++++---
 tox.ini                                |  4 ++--
 3 files changed, 10 insertions(+), 14 deletions(-)
 delete mode 100644 torch_training/bla.py

diff --git a/torch_training/bla.py b/torch_training/bla.py
deleted file mode 100644
index b5d9064..0000000
--- a/torch_training/bla.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import sys
-
-print("bla")
-def main(argv):
-    print("main bla {}".format(argv))
-
-
-if __name__ == '__main__':
-    main(sys.argv[1:])
diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index 45f60ab..5764f86 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -27,6 +27,7 @@ def main(argv):
     for opt, arg in opts:
         if opt in ('-n', '--n_trials'):
             n_trials = int(arg)
+    print("main1")
     random.seed(1)
     np.random.seed(1)
 
@@ -42,6 +43,8 @@ def main(argv):
     n_agents = np.random.randint(3, 8)
     n_goals = n_agents + np.random.randint(0, 3)
     min_dist = int(0.75 * min(x_dim, y_dim))
+    print("main2")
+
     env = RailEnv(width=x_dim,
                   height=y_dim,
                   rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
@@ -51,15 +54,16 @@ def main(argv):
                   number_of_agents=n_agents)
     env.reset(True, True)
     file_load = False
-    """
-    
-    """
+
     observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
     env_renderer = RenderTool(env, gl="PILSVG", )
     handle = env.get_agent_handles()
     features_per_node = 9
     state_size = features_per_node * 85 * 2
     action_size = 5
+
+    print("main3")
+
     # We set the number of episodes we would like to train on
     if 'n_trials' not in locals():
         n_trials = 30000
@@ -216,4 +220,5 @@ def main(argv):
 
 
 if __name__ == '__main__':
+    print("main")
     main(sys.argv[1:])
diff --git a/tox.ini b/tox.ini
index d31b261..da52855 100644
--- a/tox.ini
+++ b/tox.ini
@@ -22,8 +22,8 @@ passenv =
 deps =
     -r{toxinidir}/requirements_torch_training.txt
 commands =
-;    python torch_training/multi_agent_training.py --n_trials=10
-    python torch_training/bla.py
+    python torch_training/multi_agent_training.py --n_trials=10
+
 [flake8]
 max-line-length = 120
 ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W505
-- 
GitLab