Commit 6d7fbf29 authored by Erik Nygren's avatar Erik Nygren
Browse files

Minor updaes in handling the predictor.

parent 91ff7d01
......@@ -92,7 +92,7 @@ def coordinate_to_position(width, coords):
position = []
for t in coords:
position.append((t[1] * width + t[0]))
return np.array(position)
return np.asarray(position).flatten()
class AStarNode():
"""A node class for A* Pathfinding"""
......
......@@ -6,6 +6,7 @@ from collections import deque
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.envs.env_utils import coordinate_to_position
class TreeObsForRailEnv(ObservationBuilder):
......@@ -175,11 +176,17 @@ class TreeObsForRailEnv(ObservationBuilder):
if self.predictor:
self.predicted_pos = {}
self.predicted_dir = {}
self.predictions = self.predictor.get()
pred_pos = np.concatenate([[x[:, 1:3]] for x in list(self.predictions.values())], axis=0)
pred_pos = list(map(list, zip(*pred_pos)))
pred_dir = [x[:, 2] for x in list(self.predictions.values())]
for t in range(len(self.predictions[0])):
pos_list = []
dir_list = []
for a in handles:
pos_list.append(self.predictions[a][t][1:3])
dir_list.append(self.predictions[a][t][3])
self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
self.predicted_dir.update({t: dir_list})
observations = {}
for h in handles:
observations[h] = self.get(h)
......@@ -317,6 +324,9 @@ class TreeObsForRailEnv(ObservationBuilder):
# Cummulate the number of agents on branch with other direction
other_agent_opposite_direction += 1
if self.predictor:
# Register possible conflict
if position in self.location_has_target:
if num_steps < other_target_encountered:
other_target_encountered = num_steps
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment