diff --git a/examples/training_example.py b/examples/training_example.py
index 78336b96e2b4e9ed7da1be3253d83c769a9a21f4..218f133efe4e79064ae94b24c61aa2d6f2ce2d09 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -80,7 +80,7 @@ for trials in range(1, n_trials + 1):
         # Environment step which returns the observations for all agents, their corresponding
         # reward and whether their are done
         next_obs, all_rewards, done, _ = env.step(action_dict)
-        TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
+        # TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
             agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 885a966b87e23e6294b090fa7c9338ff3893cb60..b5d9c2e0ba6db708c394aa99ca73f1e42bb4cb43 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -178,7 +178,6 @@ class TreeObsForRailEnv(ObservationBuilder):
             self.predicted_pos = {}
             self.predicted_dir = {}
             self.predictions = self.predictor.get(self.distance_map)
-
             for t in range(len(self.predictions[0])):
                 pos_list = []
                 dir_list = []
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 988594e63ad72b895127934258be19c1de85d0b6..366acfebeb8c5ba4f56138c9bab0bcbb904a7b11 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -127,14 +127,15 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                 if np.sum(cell_transitions) == 1:
                     new_direction = np.argmax(cell_transitions)
                     new_position = self._new_position(agent.position, new_direction)
-                else:
+                elif np.sum(cell_transitions) > 1:
+                    min_dist = np.inf
                     for direct in range(4):
-                        min_dist = np.inf
-                        target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
-                        if target_dist < min_dist:
-                            min_dist = target_dist
-                            new_direction = direct
-                            new_position = self._new_position(agent.position, new_direction)
+                        if cell_transitions[direct] == 1:
+                            target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
+                            if target_dist < min_dist:
+                                min_dist = target_dist
+                                new_direction = direct
+                                new_position = self._new_position(agent.position, new_direction)
 
                 agent.position = new_position
                 agent.direction = new_direction