From 03aacc07d3d37bdc8a8640d67412cd1908793c70 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Thu, 13 Jun 2019 10:25:34 +0200 Subject: [PATCH] added function for converting coordinates to int and backa again (used to detect conflicts) --- flatland/envs/env_utils.py | 25 +++++++++++++++++++++++++ flatland/envs/observations.py | 9 ++++++--- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py index 821384d4..4d350178 100644 --- a/flatland/envs/env_utils.py +++ b/flatland/envs/env_utils.py @@ -69,6 +69,31 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p return rail_trans.is_valid(new_trans) +def position_to_coordinate(width, position): + """ + + :param width: + :param position: + :return: + """ + coords = () + for p in position: + coords = coords + ((int(p) % width, int(p) // width),) # changed x_dim to y_dim + return coords + + +def coordinate_to_position(width, coords): + """ + + :param width: + :param coords: + :return: + """ + position = [] + for t in coords: + position.append((t[1] * width + t[0])) + return np.array(position) + class AStarNode(): """A node class for A* Pathfinding""" diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 541f8ad5..8775b9d8 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -173,10 +173,13 @@ class TreeObsForRailEnv(ObservationBuilder): in the `handles' list. """ - self.predictions = [] + if self.predictor: - for a in range(len(handles)): - self.predictions.append(self.predictor.get(a)) + self.predictions = self.predictor.get() + pred_pos = np.concatenate([[x[:, 1:3]] for x in list(self.predictions.values())], axis=0) + pred_pos = list(map(list, zip(*pred_pos))) + pred_dir = [x[:, 2] for x in list(self.predictions.values())] + observations = {} for h in handles: observations[h] = self.get(h) -- GitLab