From 1b084f8fc82e5d5bce9b6492cb1bded69cf72715 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Fri, 14 Jun 2019 09:14:28 +0200
Subject: [PATCH] if agent stopps. It's predicted path is reduced to its
 current position

---
 flatland/envs/observations.py | 1 +
 flatland/envs/predictions.py  | 6 ++++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 931fabc7..5d879d26 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -178,6 +178,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             self.predicted_pos = {}
             self.predicted_dir = {}
             self.predictions = self.predictor.get(self.distance_map)
+            print(self.predictions)
             for t in range(len(self.predictions[0])):
                 pos_list = []
                 dir_list = []
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index f5f559a1..cff4e00e 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -114,12 +114,14 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
             prediction = np.zeros(shape=(self.max_depth + 1, 5))
             prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
             for index in range(1, self.max_depth + 1):
-                action_done = False
                 # if we're at the target, stop moving...
                 if agent.position == agent.target:
                     prediction[index] = [index, agent.target[0], agent.target[1], agent.direction,
                                          RailEnvActions.STOP_MOVING]
-
+                    continue
+                if not agent.moving:
+                    prediction[index] = [index, agent.position[0], agent.position[1], agent.direction,
+                                         RailEnvActions.STOP_MOVING]
                     continue
                 # Take shortest possible path
                 cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
-- 
GitLab