diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 4385d4da35202a658711127a719923f361552891..977a74a136d1f606db8d4687c47549b41d2e6185 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -22,7 +22,7 @@ class TreeObsForRailEnv(ObservationBuilder):
     For details about the features in the tree observation see the get() function.
     """
 
-    observation_dim = 9
+
 
     def __init__(self, max_depth, predictor=None):
         super().__init__()
@@ -43,6 +43,9 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
         self.distance_map = None
 
+        # this needs to be updated when new features are added!
+        self.observation_dim = 9
+
     def reset(self):
         agents = self.env.agents
         nb_agents = len(agents)