diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 4385d4da35202a658711127a719923f361552891..15d42d793b50dabe3627f6a4e601ca614f0e68b0 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -22,7 +22,7 @@ class TreeObsForRailEnv(ObservationBuilder):
     For details about the features in the tree observation see the get() function.
     """
 
-    observation_dim = 9
+
 
     def __init__(self, max_depth, predictor=None):
         super().__init__()
@@ -34,6 +34,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         for i in range(self.max_depth + 1):
             size += pow4
             pow4 *= 4
+        self.observation_dim = 9
         self.observation_space = [size * self.observation_dim]
         self.location_has_agent = {}
         self.location_has_agent_direction = {}