diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
index 29c0437e3a22c1432a410a9b892155c3a5c7cf99..e9c2a84eea5e375c024b96a35934e136bb5d40b5 100644
--- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py
+++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
@@ -6,8 +6,8 @@ from typing import List
 
 import numpy as np
 
+from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator
@@ -18,24 +18,20 @@ random.seed(100)
 np.random.seed(100)
 
 
-class SingleAgentNavigationObs(TreeObsForRailEnv):
+class SingleAgentNavigationObs(ObservationBuilder):
     """
-    We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
-    the minimum distances from each grid node to each agent's target.
-
-    We then build a representation vector with 3 binary components, indicating which of the 3 available directions
+    We build a representation vector with 3 binary components, indicating which of the 3 available directions
     for each agent (Left, Forward, Right) lead to the shortest path to its target.
     E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
     will be [1, 0, 0].
     """
 
     def __init__(self):
-        super().__init__(max_depth=0)
+        super().__init__()
         self.observation_space = [3]
 
     def reset(self):
-        # Recompute the distance map, if the environment has changed.
-        super().reset()
+        pass
 
     def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py
index d7c4475379b26fac01918bb53e37eda066d29e87..81ab13e967836ca8b631b7cab5871763f2e0fce8 100644
--- a/examples/custom_observation_example_03_ObservePredictions.py
+++ b/examples/custom_observation_example_03_ObservePredictions.py
@@ -6,8 +6,9 @@ from typing import Optional, List, Dict
 
 import numpy as np
 
+from flatland.core.env import Environment
+from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid_utils import coordinate_to_position
-from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
@@ -20,26 +21,18 @@ random.seed(100)
 np.random.seed(100)
 
 
-class ObservePredictions(TreeObsForRailEnv):
+class ObservePredictions(ObservationBuilder):
     """
     We use the provided ShortestPathPredictor to illustrate the usage of predictors in your custom observation.
-
-    We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
-    the minimum distances from each grid node to each agent's target.
-
-    This is necessary so that we can pass the distance map to the ShortestPathPredictor
-
-    Here we also want to highlight how you can visualize your observation
     """
 
     def __init__(self, predictor):
-        super().__init__(max_depth=0)
+        super().__init__()
         self.observation_space = [10]
         self.predictor = predictor
 
     def reset(self):
-        # Recompute the distance map, if the environment has changed.
-        super().reset()
+        pass
 
     def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, np.ndarray]:
         '''
@@ -106,6 +99,11 @@ class ObservePredictions(TreeObsForRailEnv):
 
         return observation
 
+    def _set_env(self, env: Environment):
+        self.env = env
+        if self.predictor:
+            self.predictor._set_env(self.env)
+
 
 def main(args):
     try:
diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py
index 9f2ee252012c52647c16e8f0b7e91cf2f9e93fbe..7cb7d9623c79d154e549114e779d39c138cf788d 100644
--- a/examples/debugging_example_DELETE.py
+++ b/examples/debugging_example_DELETE.py
@@ -4,8 +4,8 @@ from typing import List
 
 import numpy as np
 
+from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator
@@ -15,24 +15,20 @@ random.seed(1)
 np.random.seed(1)
 
 
-class SingleAgentNavigationObs(TreeObsForRailEnv):
+class SingleAgentNavigationObs(ObservationBuilder):
     """
-    We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
-    the minimum distances from each grid node to each agent's target.
-
-    We then build a representation vector with 3 binary components, indicating which of the 3 available directions
+    We build a representation vector with 3 binary components, indicating which of the 3 available directions
     for each agent (Left, Forward, Right) lead to the shortest path to its target.
     E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
     will be [1, 0, 0].
     """
 
     def __init__(self):
-        super().__init__(max_depth=0)
+        super().__init__()
         self.observation_space = [3]
 
     def reset(self):
-        # Recompute the distance map, if the environment has changed.
-        super().reset()
+        pass
 
     def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 8f7fe868f1f3bbfc766f0a1112aa9ddd384d6baa..5104b06e24a26323d798a379b65160f3a312a822 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -2,34 +2,30 @@ import random
 from typing import Dict, List
 
 import numpy as np
+from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
 
+from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
-from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
 
 
-class SingleAgentNavigationObs(TreeObsForRailEnv):
+class SingleAgentNavigationObs(ObservationBuilder):
     """
-    We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
-    the minimum distances from each grid node to each agent's target.
-
-    We then build a representation vector with 3 binary components, indicating which of the 3 available directions
+    We build a representation vector with 3 binary components, indicating which of the 3 available directions
     for each agent (Left, Forward, Right) lead to the shortest path to its target.
     E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
     will be [1, 0, 0].
     """
 
     def __init__(self):
-        super().__init__(max_depth=0)
+        super().__init__()
         self.observation_space = [3]
 
     def reset(self):
-        # Recompute the distance map, if the environment has changed.
-        super().reset()
+        pass
 
     def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]