diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py index 722f54bcb5332174199aab070b02308802619bb1..dc1cff12c0c8b1860859208a13d6403734a2d2ad 100644 --- a/flatland/envs/rail_env_utils.py +++ b/flatland/envs/rail_env_utils.py @@ -5,13 +5,15 @@ from flatland.envs.rail_generators import rail_from_file from flatland.envs.schedule_generators import schedule_from_file -def load_flatland_environment_from_file(file_name, load_from_package=None): +def load_flatland_environment_from_file(file_name, load_from_package=None, obs_builder_object=None): + if obs_builder_object is None: + obs_builder_object = TreeObsForRailEnv( + max_depth=2, + predictor=ShortestPathPredictorForRailEnv(max_depth=10)) environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package), number_of_agents=1, - schedule_generator=schedule_from_file(file_name,load_from_package), - obs_builder_object=TreeObsForRailEnv( - max_depth=2, - predictor=ShortestPathPredictorForRailEnv(max_depth=10))) + schedule_generator=schedule_from_file(file_name, load_from_package), + obs_builder_object=obs_builder_object) return environment