diff --git a/examples/custom_observation_example_01_SimpleObs.py b/examples/custom_observation_example_01_SimpleObs.py
index 70a2515b789031bf7aec4eeb4a5e9fc495a045bf..7618720fcdc798e5801b9ddb11d50214a35b2e19 100644
--- a/examples/custom_observation_example_01_SimpleObs.py
+++ b/examples/custom_observation_example_01_SimpleObs.py
@@ -17,12 +17,13 @@ class SimpleObs(ObservationBuilder):
     """
 
     def __init__(self):
+        super().__init__()
         self.observation_space = [5]
 
     def reset(self):
         return
 
-    def get(self, handle):
+    def get(self, handle: int = 0):
         observation = handle * np.ones((5,))
         return observation
 
diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
index 6c6add683bf5b7694d939bbe1a590617fb069d3e..69ae584ab6b0833c499e7fb95fc1425c39f7adaf 100644
--- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py
+++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
@@ -36,7 +36,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
         # Recompute the distance map, if the environment has changed.
         super().reset()
 
-    def get(self, handle):
+    def get(self, handle: int = 0):
         agent = self.env.agents[handle]
 
         possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py
index 92fbb37b2a0d8ff583b3056e418d11026f261032..77665e287dc7d3f49c38ee54ed1d5b372f1faa35 100644
--- a/examples/custom_observation_example_03_ObservePredictions.py
+++ b/examples/custom_observation_example_03_ObservePredictions.py
@@ -2,6 +2,7 @@ import getopt
 import random
 import sys
 import time
+from typing import Optional, List
 
 import numpy as np
 
@@ -40,7 +41,7 @@ class ObservePredictions(TreeObsForRailEnv):
         # Recompute the distance map, if the environment has changed.
         super().reset()
 
-    def get_many(self, handles=None):
+    def get_many(self, handles: Optional[List[int]] = None):
         '''
         Because we do not want to call the predictor seperately for every agent we implement the get_many function
         Here we can call the predictor just ones for all the agents and use the predictions to generate our observations
@@ -51,6 +52,10 @@ class ObservePredictions(TreeObsForRailEnv):
         self.predictions = self.predictor.get()
 
         self.predicted_pos = {}
+
+        if handles is None:
+            handles = []
+
         for t in range(len(self.predictions[0])):
             pos_list = []
             for a in handles:
@@ -64,7 +69,7 @@ class ObservePredictions(TreeObsForRailEnv):
             observations[h] = self.get(h)
         return observations
 
-    def get(self, handle):
+    def get(self, handle: int = 0):
         '''
         Lets write a simple observation which just indicates whether or not the own predicted path
         overlaps with other predicted paths at any time. This is useless for the task of navigation but might
diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py
index 8aef94c23311fc693c229924953164afb5fec8ab..562091633725c7f11c87912aa76e21e0bb039a09 100644
--- a/examples/debugging_example_DELETE.py
+++ b/examples/debugging_example_DELETE.py
@@ -33,7 +33,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
         # Recompute the distance map, if the environment has changed.
         super().reset()
 
-    def get(self, handle):
+    def get(self, handle: int = 0):
         agent = self.env.agents[handle]
 
         possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index ba79e7fc0e46d8182951b5a8e2520d7ffc9eacb0..be97a2eddb4f527d898afdcc257d445715ff8136 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -8,8 +8,12 @@ The ObservationBuilder-derived custom classes implement 2 functions, reset() and
 multi-agent environments.
 
 """
+from typing import Optional, List
+
 import numpy as np
 
+from flatland.core.env import Environment
+
 
 class ObservationBuilder:
     """
@@ -22,7 +26,7 @@ class ObservationBuilder:
     def __init__(self):
         self.observation_space = ()
 
-    def _set_env(self, env):
+    def _set_env(self, env: Environment):
         self.env = env
 
     def reset(self):
@@ -31,7 +35,7 @@ class ObservationBuilder:
         """
         raise NotImplementedError()
 
-    def get_many(self, handles=[]):
+    def get_many(self, handles: Optional[List[int]] = None):
         """
         Called whenever an observation has to be computed for the `env` environment, for each agent with handle
         in the `handles` list.
@@ -48,11 +52,13 @@ class ObservationBuilder:
             `handles` as keys.
         """
         observations = {}
+        if handles is None:
+            handles = []
         for h in handles:
             observations[h] = self.get(h)
         return observations
 
-    def get(self, handle=0):
+    def get(self, handle: int = 0):
         """
         Called whenever an observation has to be computed for the `env` environment, possibly
         for each agent independently (agent id `handle`).
@@ -83,16 +89,16 @@ class DummyObservationBuilder(ObservationBuilder):
     """
 
     def __init__(self):
-        self.observation_space = ()
+        super().__init__()
 
-    def _set_env(self, env):
+    def _set_env(self, env: Environment):
         self.env = env
 
     def reset(self):
         pass
 
-    def get_many(self, handles=[]):
+    def get_many(self, handles: Optional[List[int]] = None):
         return True
 
-    def get(self, handle=0):
+    def get(self, handle: int = 0):
         return True
diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py
index c1026fe02bfc2cfb30fd57ef570022ca3b15f6f0..e591d6cfae141aba474eb5424446515c12df89bb 100644
--- a/flatland/core/env_prediction_builder.py
+++ b/flatland/core/env_prediction_builder.py
@@ -8,6 +8,7 @@ If predictions are not required in every step or not for all agents, then
 + `get()` is called whenever an step has to be computed, potentially for each agent independently in \
 case of multi-agent environments.
 """
+from flatland.core.env import Environment
 
 
 class PredictionBuilder:
@@ -19,7 +20,7 @@ class PredictionBuilder:
     def __init__(self, max_depth: int = 20):
         self.max_depth = max_depth
 
-    def _set_env(self, env):
+    def _set_env(self, env: Environment):
         self.env = env
 
     def reset(self):
@@ -28,15 +29,12 @@ class PredictionBuilder:
         """
         pass
 
-    def get(self, handle=0):
+    def get(self, handle: int = 0):
         """
         Called whenever get_many in the observation build is called.
 
         Parameters
         ----------
-        custom_args: dict
-            Implementation-dependent custom arguments, see the sub-classes.
-
         handle : int, optional
             Handle of the agent for which to compute the observation vector.
 
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 2a1c52207630a72f2749ba22ab7c46241839d4ab..17985c44c34e5284d33a1c513a45caa93b6efba7 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -2,10 +2,13 @@
 Collection of environment-specific ObservationBuilder.
 """
 import pprint
+from typing import Optional, List
 
 import numpy as np
 
+from flatland.core.env import Environment
 from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.env_prediction_builder import PredictionBuilder
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import coordinate_to_position
 from flatland.utils.ordered_set import OrderedSet
@@ -22,7 +25,7 @@ class TreeObsForRailEnv(ObservationBuilder):
     For details about the features in the tree observation see the get() function.
     """
 
-    def __init__(self, max_depth, predictor=None):
+    def __init__(self, max_depth: int, predictor: PredictionBuilder = None):
         super().__init__()
         self.max_depth = max_depth
         self.observation_dim = 11
@@ -43,7 +46,7 @@ class TreeObsForRailEnv(ObservationBuilder):
     def reset(self):
         self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents}
 
-    def get_many(self, handles=None):
+    def get_many(self, handles: Optional[List[int]] = None):
         """
         Called whenever an observation has to be computed for the `env` environment, for each agent with handle
         in the `handles` list.
@@ -72,7 +75,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             observations[h] = self.get(h)
         return observations
 
-    def get(self, handle):
+    def get(self, handle: int = 0):
         """
         Computes the current observation for agent `handle` in env
 
@@ -487,7 +490,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 unfolded[label] = observation_tree
         return unfolded
 
-    def _set_env(self, env):
+    def _set_env(self, env: Environment):
         self.env = env
         if self.predictor:
             self.predictor._set_env(self.env)
@@ -518,7 +521,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
         self.observation_space = ()
         super(GlobalObsForRailEnv, self).__init__()
 
-    def _set_env(self, env):
+    def _set_env(self, env: Environment):
         super()._set_env(env)
 
         self.observation_space = [4, self.env.height, self.env.width]
@@ -531,7 +534,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
                 bitlist = [0] * (16 - len(bitlist)) + bitlist
                 self.rail_obs[i, j] = np.array(bitlist)
 
-    def get(self, handle):
+    def get(self, handle: int = 0):
         obs_targets = np.zeros((self.env.height, self.env.width, 2))
         obs_agents_state = np.zeros((self.env.height, self.env.width, 4))
         agents = self.env.agents
@@ -597,7 +600,7 @@ class LocalObsForRailEnv(ObservationBuilder):
                 bitlist = [0] * (16 - len(bitlist)) + bitlist
                 self.rail_obs[i, j] = np.array(bitlist)
 
-    def get(self, handle):
+    def get(self, handle: int = 0):
         agents = self.env.agents
         agent = agents[handle]
 
@@ -637,13 +640,15 @@ class LocalObsForRailEnv(ObservationBuilder):
         direction = np.identity(4)[agent.direction]
         return local_rail_obs, obs_map_state, obs_other_agents_state, direction
 
-    def get_many(self, handles=None):
+    def get_many(self, handles: Optional[List[int]] = None):
         """
         Called whenever an observation has to be computed for the `env` environment, for each agent with handle
         in the `handles` list.
         """
 
         observations = {}
+        if handles is None:
+            handles = []
         for h in handles:
             observations[h] = self.get(h)
         return observations
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index ccf4b967c3eabe2ff85dc4084720aa8fc3ca9628..77707b9f110376ddf2638b830830ff1a1c1edbf6 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -18,14 +18,12 @@ class DummyPredictorForRailEnv(PredictionBuilder):
     The prediction acts as if no other agent is in the environment and always takes the forward action.
     """
 
-    def get(self, handle=None):
+    def get(self, handle: int = None):
         """
         Called whenever get_many in the observation build is called.
 
         Parameters
         ----------
-        custom_args: dict
-            Not used in this dummy implementation.
         handle : int, optional
             Handle of the agent for which to compute the observation vector.
 
@@ -87,19 +85,16 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
     The prediction acts as if no other agent is in the environment and always takes the forward action.
     """
 
-    def __init__(self, max_depth=20):
-        # Initialize with depth 20
-        self.max_depth = max_depth
+    def __init__(self, max_depth: int = 20):
+        super().__init__(max_depth)
 
-    def get(self, handle=None):
+    def get(self, handle: int = None):
         """
         Called whenever get_many in the observation build is called.
         Requires distance_map to extract the shortest path.
 
         Parameters
         ----------
-        custom_args: dict
-            - distance_map : dict
         handle : int, optional
             Handle of the agent for which to compute the observation vector.
 
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 294ffab233458f1f3b98c18be50743ba65bd2d73..c81ef9dc82df0817f3f3fc42798392d7ffdbcf5e 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -11,6 +11,7 @@ import msgpack_numpy as m
 import numpy as np
 
 from flatland.core.env import Environment
+from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
@@ -114,7 +115,7 @@ class RailEnv(Environment):
                  rail_generator: RailGenerator = random_rail_generator(),
                  schedule_generator: ScheduleGenerator = random_schedule_generator(),
                  number_of_agents=1,
-                 obs_builder_object=TreeObsForRailEnv(max_depth=2),
+                 obs_builder_object: ObservationBuilder = TreeObsForRailEnv(max_depth=2),
                  max_episode_steps=None,
                  stochastic_data=None
                  ):
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 884a2a51f84a40a45acced32e7310dcf4d497944..55d3526757123230fb351dbf67dbfc269e58b6ac 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -32,7 +32,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
         # Recompute the distance map, if the environment has changed.
         super().reset()
 
-    def get(self, handle):
+    def get(self, handle: int = 0):
         agent = self.env.agents[handle]
 
         possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)