Skip to content
Snippets Groups Projects
Commit 1a73719c authored by u214892's avatar u214892
Browse files

#42 run baselines in ci

parent d902c726
No related branches found
No related tags found
No related merge requests found
import sys
print("bla")
def main(argv):
print("main bla {}".format(argv))
if __name__ == '__main__':
main(sys.argv[1:])
......@@ -27,6 +27,7 @@ def main(argv):
for opt, arg in opts:
if opt in ('-n', '--n_trials'):
n_trials = int(arg)
print("main1")
random.seed(1)
np.random.seed(1)
......@@ -42,6 +43,8 @@ def main(argv):
n_agents = np.random.randint(3, 8)
n_goals = n_agents + np.random.randint(0, 3)
min_dist = int(0.75 * min(x_dim, y_dim))
print("main2")
env = RailEnv(width=x_dim,
height=y_dim,
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
......@@ -51,15 +54,16 @@ def main(argv):
number_of_agents=n_agents)
env.reset(True, True)
file_load = False
"""
"""
observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
env_renderer = RenderTool(env, gl="PILSVG", )
handle = env.get_agent_handles()
features_per_node = 9
state_size = features_per_node * 85 * 2
action_size = 5
print("main3")
# We set the number of episodes we would like to train on
if 'n_trials' not in locals():
n_trials = 30000
......@@ -216,4 +220,5 @@ def main(argv):
if __name__ == '__main__':
print("main")
main(sys.argv[1:])
......@@ -22,8 +22,8 @@ passenv =
deps =
-r{toxinidir}/requirements_torch_training.txt
commands =
; python torch_training/multi_agent_training.py --n_trials=10
python torch_training/bla.py
python torch_training/multi_agent_training.py --n_trials=10
[flake8]
max-line-length = 120
ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W505
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment