diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 977a74a136d1f606db8d4687c47549b41d2e6185..4385d4da35202a658711127a719923f361552891 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -22,7 +22,7 @@ class TreeObsForRailEnv(ObservationBuilder):
     For details about the features in the tree observation see the get() function.
     """
 
-
+    observation_dim = 9
 
     def __init__(self, max_depth, predictor=None):
         super().__init__()
@@ -43,9 +43,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
         self.distance_map = None
 
-        # this needs to be updated when new features are added!
-        self.observation_dim = 9
-
     def reset(self):
         agents = self.env.agents
         nb_agents = len(agents)