diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py
index 81ab13e967836ca8b631b7cab5871763f2e0fce8..2ed47a5f8c18894c8e0a2108b7bc0a8d73b783f9 100644
--- a/examples/custom_observation_example_03_ObservePredictions.py
+++ b/examples/custom_observation_example_03_ObservePredictions.py
@@ -99,10 +99,10 @@ class ObservePredictions(ObservationBuilder):
 
         return observation
 
-    def _set_env(self, env: Environment):
-        self.env = env
+    def set_env(self, env: Environment):
+        super().set_env(env)
         if self.predictor:
-            self.predictor._set_env(self.env)
+            self.predictor.set_env(self.env)
 
 
 def main(args):
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index daae9b7c829f0a420cbee0d041b04aff16eb6a19..2d4df089eed08ee17f3d5f89147735b1b8570a7d 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -25,8 +25,9 @@ class ObservationBuilder:
 
     def __init__(self):
         self.observation_space = ()
+        self.env = None
 
-    def _set_env(self, env: Environment):
+    def set_env(self, env: Environment):
         self.env = env
 
     def reset(self):
@@ -91,9 +92,6 @@ class DummyObservationBuilder(ObservationBuilder):
     def __init__(self):
         super().__init__()
 
-    def _set_env(self, env: Environment):
-        self.env = env
-
     def reset(self):
         pass
 
diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py
index e591d6cfae141aba474eb5424446515c12df89bb..ba839c7c270c72997f4229e9690d961bee0496e3 100644
--- a/flatland/core/env_prediction_builder.py
+++ b/flatland/core/env_prediction_builder.py
@@ -19,8 +19,9 @@ class PredictionBuilder:
 
     def __init__(self, max_depth: int = 20):
         self.max_depth = max_depth
+        self.env = None
 
-    def _set_env(self, env: Environment):
+    def set_env(self, env: Environment):
         self.env = env
 
     def reset(self):
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 6312cca3e51fa0e60d2493587ff5d699fcd355c4..de9ee2a45ebdeabec5202be4f12593a82b4e20e4 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -491,10 +491,10 @@ class TreeObsForRailEnv(ObservationBuilder):
                 unfolded[label] = observation_tree
         return unfolded
 
-    def _set_env(self, env: Environment):
-        self.env = env
+    def set_env(self, env: Environment):
+        super().set_env(env)
         if self.predictor:
-            self.predictor._set_env(self.env)
+            self.predictor.set_env(self.env)
 
     def _reverse_dir(self, direction):
         return int((direction + 2) % 4)
@@ -522,9 +522,8 @@ class GlobalObsForRailEnv(ObservationBuilder):
         self.observation_space = ()
         super(GlobalObsForRailEnv, self).__init__()
 
-    def _set_env(self, env: Environment):
-        super()._set_env(env)
-
+    def set_env(self, env: Environment):
+        super().set_env(env)
         self.observation_space = [4, self.env.height, self.env.width]
 
     def reset(self):
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 1805e8c6b01a9fb88db082dc0e7de7909800c0b6..d0add3086014c7ad07c29e01588cba380c26cbd7 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -161,7 +161,7 @@ class RailEnv(Environment):
         self.rewards = [0] * number_of_agents
         self.done = False
         self.obs_builder = obs_builder_object
-        self.obs_builder._set_env(self)
+        self.obs_builder.set_env(self)
 
         self._max_episode_steps = max_episode_steps
         self._elapsed_steps = 0