From 1a91a7383d8b7358b2ef0d080d7f62ca9a51af13 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Fri, 14 Jun 2019 00:04:06 +0200
Subject: [PATCH] Minor bugfixes in shortest path predictor

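* examples/training_example.py: comment out the per-step tree observation
  debug printout.
* flatland/envs/observations.py: drop a stray blank line in
  TreeObsForRailEnv.
* flatland/envs/predictions.py: in ShortestPathPredictorForRailEnv,
  initialise min_dist once before the loop over the four directions and
  only compare directions that actually have a valid transition, so the
  direction closest to the target is chosen when a cell offers more than
  one transition.
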
---
 examples/training_example.py  |  2 +-
 flatland/envs/observations.py |  1 -
 flatland/envs/predictions.py  | 15 ++++++++-------
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/examples/training_example.py b/examples/training_example.py
index 78336b9..218f133 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -80,7 +80,7 @@ for trials in range(1, n_trials + 1):
         # Environment step which returns the observations for all agents, their corresponding
         # reward and whether they are done
         next_obs, all_rewards, done, _ = env.step(action_dict)
-        TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
+        # TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
             agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 885a966..b5d9c2e 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -178,7 +178,6 @@ class TreeObsForRailEnv(ObservationBuilder):
             self.predicted_pos = {}
             self.predicted_dir = {}
             self.predictions = self.predictor.get(self.distance_map)
-
             for t in range(len(self.predictions[0])):
                 pos_list = []
                 dir_list = []
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 988594e..366acfe 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -127,14 +127,15 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                 if np.sum(cell_transitions) == 1:
                     new_direction = np.argmax(cell_transitions)
                     new_position = self._new_position(agent.position, new_direction)
-                else:
+                elif np.sum(cell_transitions) > 1:
+                    min_dist = np.inf
                     for direct in range(4):
-                        min_dist = np.inf
-                        target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
-                        if target_dist < min_dist:
-                            min_dist = target_dist
-                            new_direction = direct
-                            new_position = self._new_position(agent.position, new_direction)
+                        if cell_transitions[direct] == 1:
+                            target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
+                            if target_dist < min_dist:
+                                min_dist = target_dist
+                                new_direction = direct
+                                new_position = self._new_position(agent.position, new_direction)
 
                 agent.position = new_position
                 agent.direction = new_direction
-- 
GitLab