diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 2a1c52207630a72f2749ba22ab7c46241839d4ab..31ad16438cbf63a532b8e90bdbd9efb704e755b7 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -6,6 +6,7 @@ import pprint
 import numpy as np
 
 from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.env_prediction_builder import PredictionBuilder
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import coordinate_to_position
 from flatland.utils.ordered_set import OrderedSet
@@ -22,7 +23,7 @@ class TreeObsForRailEnv(ObservationBuilder):
     For details about the features in the tree observation see the get() function.
     """
 
-    def __init__(self, max_depth, predictor=None):
+    def __init__(self, max_depth: int, predictor: PredictionBuilder = None):
         super().__init__()
         self.max_depth = max_depth
         self.observation_dim = 11