diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index ef6ef4e1704cb1e1c53b622d27bb48ede5af4a54..c14672819b1c0fed58705725da6dfb1feb1b9872 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -16,8 +16,8 @@ from utils.observation_utils import normalize_observation random.seed(3) np.random.seed(2) - -file_name = "./railway/simple_avoid.pkl" +""" +file_name = "./railway/complex_scene.pkl" env = RailEnv(width=10, height=20, rail_generator=rail_from_file(file_name), @@ -41,7 +41,7 @@ env = RailEnv(width=x_dim, obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) env.reset(True, True) -""" + tree_depth = 3 observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv()) env_renderer = RenderTool(env, gl="PILSVG", )