Commit 6af2e060 authored by Erik Nygren's avatar Erik Nygren
Browse files

Updated prediction in observation builder

parent d77ecfd1
Pipeline #1060 failed with stages
in 8 minutes and 8 seconds
......@@ -31,7 +31,6 @@ class TreeObsForRailEnv(ObservationBuilder):
self.location_has_agent = {}
self.location_has_agent_direction = {}
self.predictor = predictor
self.agents_previous_reset = None
def reset(self):
......@@ -175,8 +174,10 @@ class TreeObsForRailEnv(ObservationBuilder):
"""
# TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object.
self.predictions = []
if self.predictor:
print(self.predictor.get(0))
for a in range(len(handles)):
self.predictions.append(self.predictor.get(a))
observations = {}
for h in handles:
observations[h] = self.get(h)
......@@ -222,6 +223,8 @@ class TreeObsForRailEnv(ObservationBuilder):
(possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
0 = no agent present other direction than myself
#8: possible conflict detected
Missing/padding nodes are filled in with -inf (truncated).
Missing values in present node are filled in with +inf (truncated).
......@@ -256,7 +259,6 @@ class TreeObsForRailEnv(ObservationBuilder):
for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction)
branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
observation = observation + branch_observation
......@@ -294,6 +296,7 @@ class TreeObsForRailEnv(ObservationBuilder):
other_target_encountered = np.inf
other_agent_same_direction = 0
other_agent_opposite_direction = 0
possible_conflict = 0
num_steps = 1
while exploring:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment