diff --git a/run.py b/run.py index 09dc04dae944bc23aaeeeb70e0789a605eb31a4a..7ad08cd45122f68de9da711022269f4255727a69 100644 --- a/run.py +++ b/run.py @@ -9,6 +9,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.evaluators.client import FlatlandRemoteClient from flatland.evaluators.client import TimeoutException +from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent from utils.deadlock_check import check_if_all_blocked from utils.fast_tree_obs import FastTreeObs @@ -81,6 +82,8 @@ while True: nb_agents = len(local_env.agents) max_nb_steps = local_env._max_episode_steps + policy = DeadLockAvoidanceAgent(local_env) + tree_observation.set_env(local_env) tree_observation.reset() observation = tree_observation.get_many(list(range(nb_agents)))