diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py
index 52ab5fadeab17c153e5bba0414d10a2c593ad338..228d3992429cb0db4d5bd977a7c41214c3857b28 100644
--- a/flatland/envs/distance_map.py
+++ b/flatland/envs/distance_map.py
@@ -18,19 +18,18 @@ class DistanceMap:
         self.agents: List[EnvAgent] = agents
         self.rail: Optional[GridTransitionMap] = None
 
-    """
-    Set the distance map
-    """
-
     def set(self, distance_map: np.ndarray):
+        """
+        Set the distance map
+        """
         self.distance_map = distance_map
 
-    """
-    Get the distance map
-    """
 
-    def get(self) -> np.ndarray:
 
+    def get(self) -> np.ndarray:
+        """
+        Get the distance map
+        """
         if self.reset_was_called:
             self.reset_was_called = False
 
@@ -53,11 +52,12 @@ class DistanceMap:
 
         return self.distance_map
 
-    """
-    Reset the distance map
-    """
+
 
     def reset(self, agents: List[EnvAgent], rail: GridTransitionMap):
+        """
+        Reset the distance map
+        """
         self.reset_was_called = True
         self.agents: List[EnvAgent] = agents
         self.rail = rail
diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py
index 200c19dc21de48559dc1e641ab01eb244208f5aa..8e3ed5b4253fef502aa5e748083854ad9a50c0bb 100644
--- a/flatland/envs/rail_env_utils.py
+++ b/flatland/envs/rail_env_utils.py
@@ -83,22 +83,26 @@ def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
     return valid_actions
 
 
-def get_shorts_paths(distance_map: DistanceMap) -> Dict[int, List[RailEnvNextAction]]:
+def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[RailEnvNextAction]]:
+    # TODO: do we need to support unreachable targets?
+    # TODO refactoring: unify with predictor (support agent.moving and max_depth)
     shortest_paths = dict()
     for a in distance_map.agents:
         position = a.position
         direction = a.direction
         shortest_paths[a.handle] = []
-
+        distance = math.inf
         while (position != a.target):
             next_actions = get_valid_move_actions_(direction, position, distance_map.rail)
-            best = math.inf
 
             best_next_action = None
             for next_action in next_actions:
-                if distance_map.get()[a.handle, position[0], position[1], direction] < best:
+                next_action_distance = distance_map.get()[a.handle, next_action.next_position[0], next_action.next_position[1], next_action.next_direction]
+                if next_action_distance < distance:
                     best_next_action = next_action
+                    distance = next_action_distance
             position = best_next_action.next_position
             direction = best_next_action.next_direction
             shortest_paths[a.handle].append(best_next_action)
+
     return shortest_paths
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index b2a9b612e923e5a4290b938e2ddd723f03b3c1b1..2df60017bc5add7e90d1bbd3b07fd24299dbcc11 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -8,7 +8,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction
-from flatland.envs.rail_env_utils import get_shorts_paths
+from flatland.envs.rail_env_utils import get_shortest_paths
 from flatland.envs.rail_generators import rail_from_grid_transition_map
 from flatland.envs.schedule_generators import random_schedule_generator
 from flatland.utils.rendertools import RenderTool
@@ -143,7 +143,7 @@ def test_shortest_path_predictor(rendering=False):
         1], agent.direction] == 5.0, "found {} instead of {}".format(
         distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
 
-    paths = get_shorts_paths(env.distance_map)[0]
+    paths = get_shortest_paths(env.distance_map)[0]
     assert paths == [
         RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6), next_direction=0),
         RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6), next_direction=0),
diff --git a/tests/test_shortest_path.py b/tests/test_shortest_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..462164367921a58a7e00eaedfe38a7670e1d7310
--- /dev/null
+++ b/tests/test_shortest_path.py
@@ -0,0 +1,52 @@
+import numpy as np
+
+from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions
+from flatland.envs.rail_env_utils import load_flatland_environment_from_file, get_shortest_paths
+
+
+def test_get_shortest_paths():
+    env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests')
+    actual = get_shortest_paths(env.distance_map)
+
+    expected = {
+        0: [RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(1, 2), next_direction=1),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(1, 3), next_direction=1),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 3), next_direction=2),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 4), next_direction=1),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 5), next_direction=1),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 6), next_direction=1),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 7), next_direction=1),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 8), next_direction=1),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 9), next_direction=1),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 10), next_direction=1),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 11), next_direction=1),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 12), next_direction=1),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 13), next_direction=1),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 14), next_direction=1),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 15), next_direction=1),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 16), next_direction=1),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 17), next_direction=1),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 18), next_direction=1)],
+        1: [RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 17), next_direction=3),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 16), next_direction=3),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 16), next_direction=0),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 15), next_direction=3),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 14), next_direction=3),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 13), next_direction=3),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 12), next_direction=3),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 11), next_direction=3),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 10), next_direction=3),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 9), next_direction=3),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 8), next_direction=3),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 7), next_direction=3),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 6), next_direction=3),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 5), next_direction=3),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 4), next_direction=3),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 3), next_direction=3),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 2), next_direction=3),
+            RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(2, 1), next_direction=3)]
+    }
+
+    for agent_handle in expected:
+        assert np.array_equal(actual[agent_handle], expected[agent_handle]), \
+            "[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle])