diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py index 36e5305d5a5a78c600cee0ea6845c341ac2b2e6d..18e96a2b6f5c0fb2bd557989c8382a376cc71fc5 100644 --- a/examples/custom_observation_example.py +++ b/examples/custom_observation_example.py @@ -141,7 +141,7 @@ class ObservePredictions(TreeObsForRailEnv): :return: ''' - self.predictions = self.predictor.get(custom_args={'distance_map': self.env.distance_map}) + self.predictions = self.predictor.get() self.predicted_pos = {} for t in range(len(self.predictions[0])): diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py index 5ce69a8110236128b9a982e2540bc79357c1ba2d..13eb38140fb25730d1817e2db7a17c350a2260c7 100644 --- a/flatland/core/env_prediction_builder.py +++ b/flatland/core/env_prediction_builder.py @@ -28,7 +28,7 @@ class PredictionBuilder: """ pass - def get(self, custom_args=None, handle=0): + def get(self, handle=0): """ Called whenever get_many in the observation build is called. diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 4718ad9906db9b479123b53e9e9df0ff4db3b462..dc5a13b3cf51411e7f519bf2e3ee67c30c43df2e 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -17,14 +17,12 @@ class DummyPredictorForRailEnv(PredictionBuilder): The prediction acts as if no other agent is in the environment and always takes the forward action. """ - def get(self, custom_args=None, handle=None): + def get(self, handle=None): """ Called whenever get_many in the observation build is called. Parameters ------- - custom_args: dict - Not used in this dummy implementation. handle : int (optional) Handle of the agent for which to compute the observation vector. @@ -90,15 +88,13 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): # Initialize with depth 20 self.max_depth = max_depth - def get(self, custom_args=None, handle=None): + def get(self, handle=None): """ Called whenever get_many in the observation build is called. Requires distance_map to extract the shortest path. Parameters ------- - custom_args: dict - - distance_map : dict handle : int (optional) Handle of the agent for which to compute the observation vector. @@ -116,8 +112,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): agents = self.env.agents if handle: agents = [self.env.agents[handle]] - assert custom_args is not None - distance_map = custom_args.get('distance_map') + distance_map = self.env.distance_map assert distance_map is not None prediction_dict = {}