Skip to content
Snippets Groups Projects
Commit 1a73719c authored by u214892's avatar u214892
Browse files

#42 run baselines in ci

parent d902c726
No related branches found
No related tags found
No related merge requests found
import sys
print("bla")
def main(argv):
print("main bla {}".format(argv))
if __name__ == '__main__':
main(sys.argv[1:])
...@@ -27,6 +27,7 @@ def main(argv): ...@@ -27,6 +27,7 @@ def main(argv):
for opt, arg in opts: for opt, arg in opts:
if opt in ('-n', '--n_trials'): if opt in ('-n', '--n_trials'):
n_trials = int(arg) n_trials = int(arg)
print("main1")
random.seed(1) random.seed(1)
np.random.seed(1) np.random.seed(1)
...@@ -42,6 +43,8 @@ def main(argv): ...@@ -42,6 +43,8 @@ def main(argv):
n_agents = np.random.randint(3, 8) n_agents = np.random.randint(3, 8)
n_goals = n_agents + np.random.randint(0, 3) n_goals = n_agents + np.random.randint(0, 3)
min_dist = int(0.75 * min(x_dim, y_dim)) min_dist = int(0.75 * min(x_dim, y_dim))
print("main2")
env = RailEnv(width=x_dim, env = RailEnv(width=x_dim,
height=y_dim, height=y_dim,
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
...@@ -51,15 +54,16 @@ def main(argv): ...@@ -51,15 +54,16 @@ def main(argv):
number_of_agents=n_agents) number_of_agents=n_agents)
env.reset(True, True) env.reset(True, True)
file_load = False file_load = False
"""
"""
observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()) observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer = RenderTool(env, gl="PILSVG", )
handle = env.get_agent_handles() handle = env.get_agent_handles()
features_per_node = 9 features_per_node = 9
state_size = features_per_node * 85 * 2 state_size = features_per_node * 85 * 2
action_size = 5 action_size = 5
print("main3")
# We set the number of episodes we would like to train on # We set the number of episodes we would like to train on
if 'n_trials' not in locals(): if 'n_trials' not in locals():
n_trials = 30000 n_trials = 30000
...@@ -216,4 +220,5 @@ def main(argv): ...@@ -216,4 +220,5 @@ def main(argv):
if __name__ == '__main__': if __name__ == '__main__':
print("main")
main(sys.argv[1:]) main(sys.argv[1:])
...@@ -22,8 +22,8 @@ passenv = ...@@ -22,8 +22,8 @@ passenv =
deps = deps =
-r{toxinidir}/requirements_torch_training.txt -r{toxinidir}/requirements_torch_training.txt
commands = commands =
; python torch_training/multi_agent_training.py --n_trials=10 python torch_training/multi_agent_training.py --n_trials=10
python torch_training/bla.py
[flake8] [flake8]
max-line-length = 120 max-line-length = 120
ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W505 ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W505
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment