From de4b7206a3f1c6b147be6fb00c73f518509c88c7 Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Mon, 16 Sep 2019 15:59:22 +0200
Subject: [PATCH] Refactoring: prediction_builder knows its environment and can
 access the distance map directly

---
 examples/custom_observation_example.py  |  2 +-
 flatland/core/env_prediction_builder.py |  2 +-
 flatland/envs/predictions.py            | 11 +++--------
 3 files changed, 5 insertions(+), 10 deletions(-)

diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py
index 36e5305d..18e96a2b 100644
--- a/examples/custom_observation_example.py
+++ b/examples/custom_observation_example.py
@@ -141,7 +141,7 @@ class ObservePredictions(TreeObsForRailEnv):
         :return:
         '''
 
-        self.predictions = self.predictor.get(custom_args={'distance_map': self.env.distance_map})
+        self.predictions = self.predictor.get()
 
         self.predicted_pos = {}
         for t in range(len(self.predictions[0])):
diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py
index 5ce69a81..13eb3814 100644
--- a/flatland/core/env_prediction_builder.py
+++ b/flatland/core/env_prediction_builder.py
@@ -28,7 +28,7 @@ class PredictionBuilder:
         """
         pass
 
-    def get(self, custom_args=None, handle=0):
+    def get(self, handle=0):
         """
         Called whenever get_many in the observation build is called.
 
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 4718ad99..dc5a13b3 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -17,14 +17,12 @@ class DummyPredictorForRailEnv(PredictionBuilder):
     The prediction acts as if no other agent is in the environment and always takes the forward action.
     """
 
-    def get(self, custom_args=None, handle=None):
+    def get(self, handle=None):
         """
         Called whenever get_many in the observation build is called.
 
         Parameters
         -------
-        custom_args: dict
-            Not used in this dummy implementation.
         handle : int (optional)
             Handle of the agent for which to compute the observation vector.
 
@@ -90,15 +88,13 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         # Initialize with depth 20
         self.max_depth = max_depth
 
-    def get(self, custom_args=None, handle=None):
+    def get(self, handle=None):
         """
         Called whenever get_many in the observation build is called.
         Requires distance_map to extract the shortest path.
 
         Parameters
         -------
-        custom_args: dict
-            - distance_map : dict
         handle : int (optional)
             Handle of the agent for which to compute the observation vector.
 
@@ -116,8 +112,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         agents = self.env.agents
         if handle:
             agents = [self.env.agents[handle]]
-        assert custom_args is not None
-        distance_map = custom_args.get('distance_map')
+        distance_map = self.env.distance_map
         assert distance_map is not None
 
         prediction_dict = {}
-- 
GitLab