diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py
index 722f54bcb5332174199aab070b02308802619bb1..dc1cff12c0c8b1860859208a13d6403734a2d2ad 100644
--- a/flatland/envs/rail_env_utils.py
+++ b/flatland/envs/rail_env_utils.py
@@ -5,13 +5,15 @@ from flatland.envs.rail_generators import rail_from_file
 from flatland.envs.schedule_generators import schedule_from_file
 
 
-def load_flatland_environment_from_file(file_name, load_from_package=None):
+def load_flatland_environment_from_file(file_name, load_from_package=None, obs_builder_object=None):
+    if obs_builder_object is None:
+        obs_builder_object = TreeObsForRailEnv(
+            max_depth=2,
+            predictor=ShortestPathPredictorForRailEnv(max_depth=10))
     environment = RailEnv(width=1,
                           height=1,
                           rail_generator=rail_from_file(file_name, load_from_package),
                           number_of_agents=1,
-                          schedule_generator=schedule_from_file(file_name,load_from_package),
-                          obs_builder_object=TreeObsForRailEnv(
-                              max_depth=2,
-                              predictor=ShortestPathPredictorForRailEnv(max_depth=10)))
+                          schedule_generator=schedule_from_file(file_name, load_from_package),
+                          obs_builder_object=obs_builder_object)
     return environment