diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py
index c1026fe02bfc2cfb30fd57ef570022ca3b15f6f0..756068efffa995968322bfecfefee178717a226b 100644
--- a/flatland/core/env_prediction_builder.py
+++ b/flatland/core/env_prediction_builder.py
@@ -8,6 +8,7 @@ If predictions are not required in every step or not for all agents, then
 + `get()` is called whenever an step has to be computed, potentially for each agent independently in \
 case of multi-agent environments.
 """
+from flatland.core.env import Environment
 
 
 class PredictionBuilder:
@@ -19,7 +20,7 @@ class PredictionBuilder:
     def __init__(self, max_depth: int = 20):
         self.max_depth = max_depth
 
-    def _set_env(self, env):
+    def _set_env(self, env: Environment):
         self.env = env
 
     def reset(self):
@@ -28,7 +29,7 @@ class PredictionBuilder:
         """
         pass
 
-    def get(self, handle=0):
+    def get(self, handle: int = 0):
         """
         Called whenever get_many in the observation build is called.
 
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index ccf4b967c3eabe2ff85dc4084720aa8fc3ca9628..0b390f01c7d8ef876d3b30e58e0fd48f5aceecdf 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -18,7 +18,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
     The prediction acts as if no other agent is in the environment and always takes the forward action.
     """
 
-    def get(self, handle=None):
+    def get(self, handle: int = None):
         """
         Called whenever get_many in the observation build is called.
 
@@ -91,7 +91,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         # Initialize with depth 20
         self.max_depth = max_depth
 
-    def get(self, handle=None):
+    def get(self, handle: int = None):
         """
         Called whenever get_many in the observation build is called.
         Requires distance_map to extract the shortest path.