From e47b9e332534d308e7f7d68021193c6cbf460626 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Tue, 24 Sep 2019 11:54:13 +0200 Subject: [PATCH] bugfix load flatland --- flatland/envs/rail_env_utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py index 722f54bc..dc1cff12 100644 --- a/flatland/envs/rail_env_utils.py +++ b/flatland/envs/rail_env_utils.py @@ -5,13 +5,15 @@ from flatland.envs.rail_generators import rail_from_file from flatland.envs.schedule_generators import schedule_from_file -def load_flatland_environment_from_file(file_name, load_from_package=None): +def load_flatland_environment_from_file(file_name, load_from_package=None, obs_builder_object=None): + if obs_builder_object is None: + obs_builder_object = TreeObsForRailEnv( + max_depth=2, + predictor=ShortestPathPredictorForRailEnv(max_depth=10)) environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package), number_of_agents=1, - schedule_generator=schedule_from_file(file_name,load_from_package), - obs_builder_object=TreeObsForRailEnv( - max_depth=2, - predictor=ShortestPathPredictorForRailEnv(max_depth=10))) + schedule_generator=schedule_from_file(file_name, load_from_package), + obs_builder_object=obs_builder_object) return environment -- GitLab