diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..03bbe4b71330a0d2ac2fb19e0a2afb4ae2363301
--- /dev/null
+++ b/examples/custom_observation_example.py
@@ -0,0 +1,32 @@
+import random
+
+from flatland.envs.generators import random_rail_generator, random_rail_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool
+from flatland.core.env_observation_builder import ObservationBuilder
+import numpy as np
+
+random.seed(100)
+np.random.seed(100)
+
+class CustomObs(ObservationBuilder):
+    def __init__(self):
+        self.observation_space = [5]
+
+    def reset(self):
+        return
+
+    def get(self, handle):
+        observation = handle*np.ones((5,))
+        return observation
+
+env = RailEnv(width=7,
+              height=7,
+              rail_generator=random_rail_generator(),
+              number_of_agents=3,
+              obs_builder_object=CustomObs())
+
+# Print the observation vector for each agents
+obs, all_rewards, done, _ = env.step({0: 0})
+for i in range(env.get_num_agents()):
+    print("Agent ", i,"'s observation: ", obs[i])
diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..71f849de24ac62656897e065341df068b5c0f6f4
--- /dev/null
+++ b/examples/custom_railmap_example.py
@@ -0,0 +1,37 @@
+import random
+
+from flatland.envs.generators import random_rail_generator, random_rail_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.core.transitions import RailEnvTransitions
+from flatland.core.transition_map import GridTransitionMap
+from flatland.utils.rendertools import RenderTool
+import numpy as np
+
+random.seed(100)
+np.random.seed(100)
+
+def custom_rail_generator():
+    def generator(width, height, num_agents=0, num_resets=0):
+        rail_trans = RailEnvTransitions()
+        grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
+        rail_array = grid_map.grid
+        rail_array.fill(0)
+
+        agents_positions = []
+        agents_direction = []
+        agents_target = []
+
+        return grid_map, agents_positions, agents_direction, agents_target
+    return generator
+
+env = RailEnv(width=6,
+              height=4,
+              rail_generator=custom_rail_generator(),
+              number_of_agents=1)
+
+env.reset()
+
+env_renderer = RenderTool(env, gl="QT")
+env_renderer.renderEnv(show=True)
+
+input("Press Enter to continue...")
diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py
index e0830ff751d15d425d22fcdc6c38b5ffc68197d5..8aac0ccc97fdb619fd2feaac76fed6607dc74867 100644
--- a/examples/simple_example_3.py
+++ b/examples/simple_example_3.py
@@ -3,6 +3,7 @@ import random
 from flatland.envs.generators import random_rail_generator, random_rail_generator
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
+from flatland.core.env_observation_builder import ObservationBuilder
 import numpy as np
 
 random.seed(100)
@@ -11,7 +12,8 @@ np.random.seed(100)
 env = RailEnv(width=7,
               height=7,
               rail_generator=random_rail_generator(),
-              number_of_agents=2)
+              number_of_agents=2,
+              obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
 # Print the distance map of each cell to the target of the first agent
 # for i in range(4):
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index af789f5d064f0e2b46b2328ebbfa55057bffdc14..7de6327121cf7a650bd36fad6faf54755a86b2d5 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -491,8 +491,14 @@ class GlobalObsForRailEnv(ObservationBuilder):
     """
 
     def __init__(self):
+        self.observation_space = ()
         super(GlobalObsForRailEnv, self).__init__()
 
+    def _set_env(self, env):
+        super()._set_env(env)
+
+        self.observation_space = [4, self.env.height, self.env.width]
+
     def reset(self):
         self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
         for i in range(self.rail_obs.shape[0]):
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 4ac00f182ba75303e4a103a9bd16ee40d26ce9d0..be3f4d6bd3fb485508b8cf16823879dfafff1ae9 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -90,6 +90,9 @@ class RailEnv(Environment):
         self.obs_builder = obs_builder_object
         self.obs_builder._set_env(self)
 
+        self.action_space = [1]
+        self.observation_space = self.obs_builder.observation_space # updated on resets?
+
         self.actions = [0] * number_of_agents
         self.rewards = [0] * number_of_agents
         self.done = False
@@ -112,10 +115,6 @@ class RailEnv(Environment):
 
         self.valid_positions = None
 
-        self.action_space = [1]
-        self.observation_space = self.obs_builder.observation_space # updated on resets?
-
-
     # no more agent_handles
     def get_agent_handles(self):
         return range(self.get_num_agents())