From de4b7206a3f1c6b147be6fb00c73f518509c88c7 Mon Sep 17 00:00:00 2001 From: u229589 <christian.baumberger@sbb.ch> Date: Mon, 16 Sep 2019 15:59:22 +0200 Subject: [PATCH] Refactoring: prediction_builder knows its environment and can access the distance map directly --- examples/custom_observation_example.py | 2 +- flatland/core/env_prediction_builder.py | 2 +- flatland/envs/predictions.py | 11 +++-------- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py index 36e5305d..18e96a2b 100644 --- a/examples/custom_observation_example.py +++ b/examples/custom_observation_example.py @@ -141,7 +141,7 @@ class ObservePredictions(TreeObsForRailEnv): :return: ''' - self.predictions = self.predictor.get(custom_args={'distance_map': self.env.distance_map}) + self.predictions = self.predictor.get() self.predicted_pos = {} for t in range(len(self.predictions[0])): diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py index 5ce69a81..13eb3814 100644 --- a/flatland/core/env_prediction_builder.py +++ b/flatland/core/env_prediction_builder.py @@ -28,7 +28,7 @@ class PredictionBuilder: """ pass - def get(self, custom_args=None, handle=0): + def get(self, handle=0): """ Called whenever get_many in the observation build is called. diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 4718ad99..dc5a13b3 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -17,14 +17,12 @@ class DummyPredictorForRailEnv(PredictionBuilder): The prediction acts as if no other agent is in the environment and always takes the forward action. """ - def get(self, custom_args=None, handle=None): + def get(self, handle=None): """ Called whenever get_many in the observation build is called. Parameters ------- - custom_args: dict - Not used in this dummy implementation. handle : int (optional) Handle of the agent for which to compute the observation vector. @@ -90,15 +88,13 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): # Initialize with depth 20 self.max_depth = max_depth - def get(self, custom_args=None, handle=None): + def get(self, handle=None): """ Called whenever get_many in the observation build is called. Requires distance_map to extract the shortest path. Parameters ------- - custom_args: dict - - distance_map : dict handle : int (optional) Handle of the agent for which to compute the observation vector. @@ -116,8 +112,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): agents = self.env.agents if handle: agents = [self.env.agents[handle]] - assert custom_args is not None - distance_map = custom_args.get('distance_map') + distance_map = self.env.distance_map assert distance_map is not None prediction_dict = {} -- GitLab