From 6d7fbf29866e72e5d8aebc39963446ec76486be7 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Thu, 13 Jun 2019 13:36:04 +0200 Subject: [PATCH] Minor updaes in handling the predictor. --- flatland/envs/env_utils.py | 2 +- flatland/envs/observations.py | 18 ++++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py index 4d35017..fc2cdfb 100644 --- a/flatland/envs/env_utils.py +++ b/flatland/envs/env_utils.py @@ -92,7 +92,7 @@ def coordinate_to_position(width, coords): position = [] for t in coords: position.append((t[1] * width + t[0])) - return np.array(position) + return np.asarray(position).flatten() class AStarNode(): """A node class for A* Pathfinding""" diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 8775b9d..4b598c9 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -6,6 +6,7 @@ from collections import deque import numpy as np from flatland.core.env_observation_builder import ObservationBuilder +from flatland.envs.env_utils import coordinate_to_position class TreeObsForRailEnv(ObservationBuilder): @@ -175,11 +176,17 @@ class TreeObsForRailEnv(ObservationBuilder): if self.predictor: + self.predicted_pos = {} + self.predicted_dir = {} self.predictions = self.predictor.get() - pred_pos = np.concatenate([[x[:, 1:3]] for x in list(self.predictions.values())], axis=0) - pred_pos = list(map(list, zip(*pred_pos))) - pred_dir = [x[:, 2] for x in list(self.predictions.values())] - + for t in range(len(self.predictions[0])): + pos_list = [] + dir_list = [] + for a in handles: + pos_list.append(self.predictions[a][t][1:3]) + dir_list.append(self.predictions[a][t][3]) + self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) + self.predicted_dir.update({t: dir_list}) observations = {} for h in handles: observations[h] = self.get(h) @@ -317,6 +324,9 @@ class TreeObsForRailEnv(ObservationBuilder): # Cummulate the number of agents on branch with other direction other_agent_opposite_direction += 1 + if self.predictor: + # Register possible conflict + if position in self.location_has_target: if num_steps < other_target_encountered: other_target_encountered = num_steps -- GitLab