diff --git a/docs/intro_observationbuilder.rst b/docs/intro_observationbuilder.rst index 563dd113e9df748a99d41dacdef19193dc0f1c01..64e953da23870d4653707c20c222c296db9b71f6 100644 --- a/docs/intro_observationbuilder.rst +++ b/docs/intro_observationbuilder.rst @@ -205,7 +205,7 @@ In contrast to the previous examples we also implement the :code:`def get_many(s :return: ''' - self.predictions = self.predictor.get(custom_args={'distance_map': self.env.distance_map}) + self.predictions = self.predictor.get() self.predicted_pos = {} for t in range(len(self.predictions[0])): diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index baa378f1eaa148fb6546e7b207a12ce350403dfe..3f398c74468ccd69c721ac8b17e75c625221fb07 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -71,7 +71,7 @@ class TreeObsForRailEnv(ObservationBuilder): self.max_prediction_depth = 0 self.predicted_pos = {} self.predicted_dir = {} - self.predictions = self.predictor.get(custom_args={'distance_map': self.env.distance_map}) + self.predictions = self.predictor.get() if self.predictions: for t in range(len(self.predictions[0])):