From 03aacc07d3d37bdc8a8640d67412cd1908793c70 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Thu, 13 Jun 2019 10:25:34 +0200
Subject: [PATCH] added function for converting coordinates to int and backa
 again (used to detect conflicts)

---
 flatland/envs/env_utils.py    | 25 +++++++++++++++++++++++++
 flatland/envs/observations.py |  9 ++++++---
 2 files changed, 31 insertions(+), 3 deletions(-)

diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py
index 821384d4..4d350178 100644
--- a/flatland/envs/env_utils.py
+++ b/flatland/envs/env_utils.py
@@ -69,6 +69,31 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
     return rail_trans.is_valid(new_trans)
 
 
+def position_to_coordinate(width, position):
+    """
+
+    :param width:
+    :param position:
+    :return:
+    """
+    coords = ()
+    for p in position:
+        coords = coords + ((int(p) % width, int(p) // width),)  # changed x_dim to y_dim
+    return coords
+
+
+def coordinate_to_position(width, coords):
+    """
+
+    :param width:
+    :param coords:
+    :return:
+    """
+    position = []
+    for t in coords:
+        position.append((t[1] * width + t[0]))
+    return np.array(position)
+
 class AStarNode():
     """A node class for A* Pathfinding"""
 
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 541f8ad5..8775b9d8 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -173,10 +173,13 @@ class TreeObsForRailEnv(ObservationBuilder):
         in the `handles' list.
         """
 
-        self.predictions = []
+
         if self.predictor:
-            for a in range(len(handles)):
-                self.predictions.append(self.predictor.get(a))
+            self.predictions = self.predictor.get()
+            pred_pos = np.concatenate([[x[:, 1:3]] for x in list(self.predictions.values())], axis=0)
+            pred_pos = list(map(list, zip(*pred_pos)))
+            pred_dir = [x[:, 2] for x in list(self.predictions.values())]
+
         observations = {}
         for h in handles:
             observations[h] = self.get(h)
-- 
GitLab