Skip to content
Snippets Groups Projects
Commit de4b7206 authored by u229589's avatar u229589
Browse files

Refactoring: prediction_builder knows its environment and can access the distance map directly

parent 5f4773a1
No related branches found
No related tags found
No related merge requests found
......@@ -141,7 +141,7 @@ class ObservePredictions(TreeObsForRailEnv):
:return:
'''
self.predictions = self.predictor.get(custom_args={'distance_map': self.env.distance_map})
self.predictions = self.predictor.get()
self.predicted_pos = {}
for t in range(len(self.predictions[0])):
......
......@@ -28,7 +28,7 @@ class PredictionBuilder:
"""
pass
def get(self, custom_args=None, handle=0):
def get(self, handle=0):
"""
Called whenever get_many in the observation build is called.
......
......@@ -17,14 +17,12 @@ class DummyPredictorForRailEnv(PredictionBuilder):
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def get(self, custom_args=None, handle=None):
def get(self, handle=None):
"""
Called whenever get_many in the observation build is called.
Parameters
-------
custom_args: dict
Not used in this dummy implementation.
handle : int (optional)
Handle of the agent for which to compute the observation vector.
......@@ -90,15 +88,13 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
# Initialize with depth 20
self.max_depth = max_depth
def get(self, custom_args=None, handle=None):
def get(self, handle=None):
"""
Called whenever get_many in the observation build is called.
Requires distance_map to extract the shortest path.
Parameters
-------
custom_args: dict
- distance_map : dict
handle : int (optional)
Handle of the agent for which to compute the observation vector.
......@@ -116,8 +112,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
agents = self.env.agents
if handle:
agents = [self.env.agents[handle]]
assert custom_args is not None
distance_map = custom_args.get('distance_map')
distance_map = self.env.distance_map
assert distance_map is not None
prediction_dict = {}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment