diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 30c0fabeae0d7f9dfd49ba05983ad90f457cd5f2..a08a5f73cee9dd0c389597cc4abb4dfaefd52cc0 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -58,7 +58,7 @@ class TreeObsForRailEnv(ObservationBuilder):
     def reset(self):
         self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents}
 
-    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, List[int]]:
+    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]:
         """
         Called whenever an observation has to be computed for the `env` environment, for each agent with handle
         in the `handles` list.
@@ -87,7 +87,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             observations[h] = self.get(h)
         return observations
 
-    def get(self, handle: int = 0) -> List[int]:
+    def get(self, handle: int = 0) -> Node:
         """
         Computes the current observation for agent `handle` in env