From fe26db7d95b144d437120662437409981c659e1f Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Mon, 5 Aug 2019 16:26:38 -0400
Subject: [PATCH] Updated predictor to respect differential speed in env.

---
 examples/training_example.py |  2 +-
 flatland/envs/predictions.py | 12 +++++++-----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/examples/training_example.py b/examples/training_example.py
index cfed6c92..521cca8c 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -84,7 +84,7 @@ for trials in range(1, n_trials + 1):
         # Environment step which returns the observations for all agents, their corresponding
         # reward and whether their are done
         next_obs, all_rewards, done, _ = env.step(action_dict)
-        env_renderer.render_env(show=True, show_observations=True, show_predictions=True)
+        env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
 
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 88a79ea7..4718ad99 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -124,8 +124,12 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         for agent in agents:
             _agent_initial_position = agent.position
             _agent_initial_direction = agent.direction
+            agent_speed = agent.speed_data["speed"]
+            times_per_cell = int(np.reciprocal(agent_speed))
             prediction = np.zeros(shape=(self.max_depth + 1, 5))
             prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
+            new_direction = _agent_initial_direction
+            new_position = _agent_initial_position
             visited = set()
             for index in range(1, self.max_depth + 1):
                 # if we're at the target, stop moving...
@@ -140,12 +144,10 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                 # Take shortest possible path
                 cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
 
-                new_position = None
-                new_direction = None
-                if np.sum(cell_transitions) == 1:
+                if np.sum(cell_transitions) == 1 and index % times_per_cell == 0:
                     new_direction = np.argmax(cell_transitions)
                     new_position = get_new_position(agent.position, new_direction)
-                elif np.sum(cell_transitions) > 1:
+                elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0:
                     min_dist = np.inf
                     no_dist_found = True
                     for direction in range(4):
@@ -157,7 +159,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                                 new_direction = direction
                                 no_dist_found = False
                     new_position = get_new_position(agent.position, new_direction)
-                else:
+                elif index % times_per_cell == 0:
                     raise Exception("No transition possible {}".format(cell_transitions))
 
                 # update the agent's position and direction
-- 
GitLab