From 6d7fbf29866e72e5d8aebc39963446ec76486be7 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Thu, 13 Jun 2019 13:36:04 +0200
Subject: [PATCH] Minor updaes in handling the predictor.

---
 flatland/envs/env_utils.py    |  2 +-
 flatland/envs/observations.py | 18 ++++++++++++++----
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py
index 4d35017..fc2cdfb 100644
--- a/flatland/envs/env_utils.py
+++ b/flatland/envs/env_utils.py
@@ -92,7 +92,7 @@ def coordinate_to_position(width, coords):
     position = []
     for t in coords:
         position.append((t[1] * width + t[0]))
-    return np.array(position)
+    return np.asarray(position).flatten()
 
 class AStarNode():
     """A node class for A* Pathfinding"""
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 8775b9d..4b598c9 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -6,6 +6,7 @@ from collections import deque
 import numpy as np
 
 from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.envs.env_utils import coordinate_to_position
 
 
 class TreeObsForRailEnv(ObservationBuilder):
@@ -175,11 +176,17 @@ class TreeObsForRailEnv(ObservationBuilder):
 
 
         if self.predictor:
+            self.predicted_pos = {}
+            self.predicted_dir = {}
             self.predictions = self.predictor.get()
-            pred_pos = np.concatenate([[x[:, 1:3]] for x in list(self.predictions.values())], axis=0)
-            pred_pos = list(map(list, zip(*pred_pos)))
-            pred_dir = [x[:, 2] for x in list(self.predictions.values())]
-
+            for t in range(len(self.predictions[0])):
+                pos_list = []
+                dir_list = []
+                for a in handles:
+                    pos_list.append(self.predictions[a][t][1:3])
+                    dir_list.append(self.predictions[a][t][3])
+                self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
+                self.predicted_dir.update({t: dir_list})
         observations = {}
         for h in handles:
             observations[h] = self.get(h)
@@ -317,6 +324,9 @@ class TreeObsForRailEnv(ObservationBuilder):
                     # Cummulate the number of agents on branch with other direction
                     other_agent_opposite_direction += 1
 
+            if self.predictor:
+            # Register possible conflict
+
             if position in self.location_has_target:
                 if num_steps < other_target_encountered:
                     other_target_encountered = num_steps
-- 
GitLab