Skip to content
Snippets Groups Projects
Commit 6af2e060 authored by Erik Nygren's avatar Erik Nygren
Browse files

Updated prediction in observation builder

parent d77ecfd1
No related branches found
No related tags found
No related merge requests found
...@@ -31,7 +31,6 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -31,7 +31,6 @@ class TreeObsForRailEnv(ObservationBuilder):
self.location_has_agent = {} self.location_has_agent = {}
self.location_has_agent_direction = {} self.location_has_agent_direction = {}
self.predictor = predictor self.predictor = predictor
self.agents_previous_reset = None self.agents_previous_reset = None
def reset(self): def reset(self):
...@@ -175,8 +174,10 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -175,8 +174,10 @@ class TreeObsForRailEnv(ObservationBuilder):
""" """
# TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object. # TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object.
self.predictions = []
if self.predictor: if self.predictor:
print(self.predictor.get(0)) for a in range(len(handles)):
self.predictions.append(self.predictor.get(a))
observations = {} observations = {}
for h in handles: for h in handles:
observations[h] = self.get(h) observations[h] = self.get(h)
...@@ -222,6 +223,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -222,6 +223,8 @@ class TreeObsForRailEnv(ObservationBuilder):
(possible future use: number of other agents in other direction in this branch, ie. number of conflicts) (possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
0 = no agent present other direction than myself 0 = no agent present other direction than myself
#8: possible conflict detected
Missing/padding nodes are filled in with -inf (truncated). Missing/padding nodes are filled in with -inf (truncated).
Missing values in present node are filled in with +inf (truncated). Missing values in present node are filled in with +inf (truncated).
...@@ -256,7 +259,6 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -256,7 +259,6 @@ class TreeObsForRailEnv(ObservationBuilder):
for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
if possible_transitions[branch_direction]: if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction) new_cell = self._new_position(agent.position, branch_direction)
branch_observation, branch_visited = \ branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, root_observation, 1) self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
observation = observation + branch_observation observation = observation + branch_observation
...@@ -294,6 +296,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -294,6 +296,7 @@ class TreeObsForRailEnv(ObservationBuilder):
other_target_encountered = np.inf other_target_encountered = np.inf
other_agent_same_direction = 0 other_agent_same_direction = 0
other_agent_opposite_direction = 0 other_agent_opposite_direction = 0
possible_conflict = 0
num_steps = 1 num_steps = 1
while exploring: while exploring:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment