diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py
index 36e5305d5a5a78c600cee0ea6845c341ac2b2e6d..18e96a2b6f5c0fb2bd557989c8382a376cc71fc5 100644
--- a/examples/custom_observation_example.py
+++ b/examples/custom_observation_example.py
@@ -141,7 +141,7 @@ class ObservePredictions(TreeObsForRailEnv):
         :return:
         '''
 
-        self.predictions = self.predictor.get(custom_args={'distance_map': self.env.distance_map})
+        self.predictions = self.predictor.get()
 
         self.predicted_pos = {}
         for t in range(len(self.predictions[0])):
diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py
index 5ce69a8110236128b9a982e2540bc79357c1ba2d..13eb38140fb25730d1817e2db7a17c350a2260c7 100644
--- a/flatland/core/env_prediction_builder.py
+++ b/flatland/core/env_prediction_builder.py
@@ -28,7 +28,7 @@ class PredictionBuilder:
         """
         pass
 
-    def get(self, custom_args=None, handle=0):
+    def get(self, handle=0):
         """
         Called whenever get_many in the observation build is called.
 
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 4718ad9906db9b479123b53e9e9df0ff4db3b462..dc5a13b3cf51411e7f519bf2e3ee67c30c43df2e 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -17,14 +17,12 @@ class DummyPredictorForRailEnv(PredictionBuilder):
     The prediction acts as if no other agent is in the environment and always takes the forward action.
     """
 
-    def get(self, custom_args=None, handle=None):
+    def get(self, handle=None):
         """
         Called whenever get_many in the observation build is called.
 
         Parameters
         -------
-        custom_args: dict
-            Not used in this dummy implementation.
         handle : int (optional)
             Handle of the agent for which to compute the observation vector.
 
@@ -90,15 +88,13 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         # Initialize with depth 20
         self.max_depth = max_depth
 
-    def get(self, custom_args=None, handle=None):
+    def get(self, handle=None):
         """
         Called whenever get_many in the observation build is called.
         Requires distance_map to extract the shortest path.
 
         Parameters
         -------
-        custom_args: dict
-            - distance_map : dict
         handle : int (optional)
             Handle of the agent for which to compute the observation vector.
 
@@ -116,8 +112,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         agents = self.env.agents
         if handle:
             agents = [self.env.agents[handle]]
-        assert custom_args is not None
-        distance_map = custom_args.get('distance_map')
+        distance_map = self.env.distance_map
         assert distance_map is not None
 
         prediction_dict = {}