diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
index 29c0437e3a22c1432a410a9b892155c3a5c7cf99..e9c2a84eea5e375c024b96a35934e136bb5d40b5 100644
--- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py
+++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
@@ -6,8 +6,8 @@ from typing import List
 
 import numpy as np
 
+from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator
@@ -18,24 +18,20 @@ random.seed(100)
 np.random.seed(100)
 
 
-class SingleAgentNavigationObs(TreeObsForRailEnv):
+class SingleAgentNavigationObs(ObservationBuilder):
     """
-    We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
-    the minimum distances from each grid node to each agent's target.
-
-    We then build a representation vector with 3 binary components, indicating which of the 3 available directions
+    We build a representation vector with 3 binary components, indicating which of the 3 available directions
     for each agent (Left, Forward, Right) lead to the shortest path to its target.
     E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
     will be [1, 0, 0].
     """
 
     def __init__(self):
-        super().__init__(max_depth=0)
+        super().__init__()
         self.observation_space = [3]
 
     def reset(self):
-        # Recompute the distance map, if the environment has changed.
-        super().reset()
+        pass
 
     def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py
index d7c4475379b26fac01918bb53e37eda066d29e87..2ed47a5f8c18894c8e0a2108b7bc0a8d73b783f9 100644
--- a/examples/custom_observation_example_03_ObservePredictions.py
+++ b/examples/custom_observation_example_03_ObservePredictions.py
@@ -6,8 +6,9 @@ from typing import Optional, List, Dict
 
 import numpy as np
 
+from flatland.core.env import Environment
+from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid_utils import coordinate_to_position
-from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
@@ -20,26 +21,18 @@ random.seed(100)
 np.random.seed(100)
 
 
-class ObservePredictions(TreeObsForRailEnv):
+class ObservePredictions(ObservationBuilder):
     """
     We use the provided ShortestPathPredictor to illustrate the usage of predictors in your custom observation.
-
-    We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
-    the minimum distances from each grid node to each agent's target.
-
-    This is necessary so that we can pass the distance map to the ShortestPathPredictor
-
-    Here we also want to highlight how you can visualize your observation
     """
 
     def __init__(self, predictor):
-        super().__init__(max_depth=0)
+        super().__init__()
         self.observation_space = [10]
         self.predictor = predictor
 
     def reset(self):
-        # Recompute the distance map, if the environment has changed.
-        super().reset()
+        pass
 
     def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, np.ndarray]:
         '''
@@ -106,6 +99,11 @@ class ObservePredictions(TreeObsForRailEnv):
 
         return observation
 
+    def set_env(self, env: Environment):
+        super().set_env(env)
+        if self.predictor:
+            self.predictor.set_env(self.env)
+
 
 def main(args):
     try:
diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py
index 9f2ee252012c52647c16e8f0b7e91cf2f9e93fbe..7cb7d9623c79d154e549114e779d39c138cf788d 100644
--- a/examples/debugging_example_DELETE.py
+++ b/examples/debugging_example_DELETE.py
@@ -4,8 +4,8 @@ from typing import List
 
 import numpy as np
 
+from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator
@@ -15,24 +15,20 @@ random.seed(1)
 np.random.seed(1)
 
 
-class SingleAgentNavigationObs(TreeObsForRailEnv):
+class SingleAgentNavigationObs(ObservationBuilder):
     """
-    We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
-    the minimum distances from each grid node to each agent's target.
-
-    We then build a representation vector with 3 binary components, indicating which of the 3 available directions
+    We build a representation vector with 3 binary components, indicating which of the 3 available directions
     for each agent (Left, Forward, Right) lead to the shortest path to its target.
     E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
     will be [1, 0, 0].
     """
 
     def __init__(self):
-        super().__init__(max_depth=0)
+        super().__init__()
         self.observation_space = [3]
 
     def reset(self):
-        # Recompute the distance map, if the environment has changed.
-        super().reset()
+        pass
 
     def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index daae9b7c829f0a420cbee0d041b04aff16eb6a19..2d4df089eed08ee17f3d5f89147735b1b8570a7d 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -25,8 +25,9 @@ class ObservationBuilder:
 
     def __init__(self):
         self.observation_space = ()
+        self.env = None
 
-    def _set_env(self, env: Environment):
+    def set_env(self, env: Environment):
         self.env = env
 
     def reset(self):
@@ -91,9 +92,6 @@ class DummyObservationBuilder(ObservationBuilder):
     def __init__(self):
         super().__init__()
 
-    def _set_env(self, env: Environment):
-        self.env = env
-
     def reset(self):
         pass
 
diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py
index e591d6cfae141aba474eb5424446515c12df89bb..ba839c7c270c72997f4229e9690d961bee0496e3 100644
--- a/flatland/core/env_prediction_builder.py
+++ b/flatland/core/env_prediction_builder.py
@@ -19,8 +19,9 @@ class PredictionBuilder:
 
     def __init__(self, max_depth: int = 20):
         self.max_depth = max_depth
+        self.env = None
 
-    def _set_env(self, env: Environment):
+    def set_env(self, env: Environment):
         self.env = env
 
     def reset(self):
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 6312cca3e51fa0e60d2493587ff5d699fcd355c4..de9ee2a45ebdeabec5202be4f12593a82b4e20e4 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -491,10 +491,10 @@ class TreeObsForRailEnv(ObservationBuilder):
                 unfolded[label] = observation_tree
         return unfolded
 
-    def _set_env(self, env: Environment):
-        self.env = env
+    def set_env(self, env: Environment):
+        super().set_env(env)
         if self.predictor:
-            self.predictor._set_env(self.env)
+            self.predictor.set_env(self.env)
 
     def _reverse_dir(self, direction):
         return int((direction + 2) % 4)
@@ -522,9 +522,8 @@ class GlobalObsForRailEnv(ObservationBuilder):
         self.observation_space = ()
         super(GlobalObsForRailEnv, self).__init__()
 
-    def _set_env(self, env: Environment):
-        super()._set_env(env)
-
+    def set_env(self, env: Environment):
+        super().set_env(env)
         self.observation_space = [4, self.env.height, self.env.width]
 
     def reset(self):
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 1805e8c6b01a9fb88db082dc0e7de7909800c0b6..d0add3086014c7ad07c29e01588cba380c26cbd7 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -161,7 +161,7 @@ class RailEnv(Environment):
         self.rewards = [0] * number_of_agents
         self.done = False
         self.obs_builder = obs_builder_object
-        self.obs_builder._set_env(self)
+        self.obs_builder.set_env(self)
 
         self._max_episode_steps = max_episode_steps
         self._elapsed_steps = 0
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 8f7fe868f1f3bbfc766f0a1112aa9ddd384d6baa..5104b06e24a26323d798a379b65160f3a312a822 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -2,34 +2,30 @@ import random
 from typing import Dict, List
 
 import numpy as np
+from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
 
+from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
-from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
 
 
-class SingleAgentNavigationObs(TreeObsForRailEnv):
+class SingleAgentNavigationObs(ObservationBuilder):
     """
-    We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
-    the minimum distances from each grid node to each agent's target.
-
-    We then build a representation vector with 3 binary components, indicating which of the 3 available directions
+    We build a representation vector with 3 binary components, indicating which of the 3 available directions
     for each agent (Left, Forward, Right) lead to the shortest path to its target.
     E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
     will be [1, 0, 0].
     """
 
     def __init__(self):
-        super().__init__(max_depth=0)
+        super().__init__()
         self.observation_space = [3]
 
     def reset(self):
-        # Recompute the distance map, if the environment has changed.
-        super().reset()
+        pass
 
     def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]