diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 1f02d518a9714de91d8910b8cb1408f25eb3fe88..18af8a0c88067ee00f9d9b2fb6f8bbbecc486909 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -209,8 +209,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         #1: if own target lies on the explored branch the current distance from the agent in number of cells is stored.
 
         #2: if another agents target is detected the distance in number of cells from the agents current locaiton
-        is stored
-
+            is stored
 
         #3: if another agent is detected the distance in number of cells from current agent position is stored.
 
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 671b349a4794f565b40e0d085393af6b92a08989..3dbd163ca2d67ffa698a5fe2c931ce98c2dd8bbc 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -140,16 +140,15 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                     new_position = get_new_position(agent.position, new_direction)
                 elif np.sum(cell_transitions) > 1:
                     min_dist = np.inf
+                    no_dist_found = True
                     for direction in range(4):
                         if cell_transitions[direction] == 1:
                             neighbour_cell = get_new_position(agent.position, direction)
                             target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
-                            if target_dist < min_dist:
+                            if target_dist < min_dist or no_dist_found:
                                 min_dist = target_dist
                                 new_direction = direction
-                    if new_direction == None:
-                        prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
-                        continue
+                                no_dist_found = False
                     new_position = get_new_position(agent.position, new_direction)
                 else:
                     raise Exception("No transition possible {}".format(cell_transitions))
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index f5f46408875e6235493682014d8bc4313ad5ea34..8abfd1b3b295bbe347ff5a25a5cd685314e1621a 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -249,10 +249,10 @@ class RailEnv(Environment):
             action_selected = False
             if agent.speed_data['position_fraction'] == 0.:
                 if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
-                    cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
+                    cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
                         self._check_action_on_agent(action, agent)
 
-                    if all([new_cell_isValid, transition_isValid]):
+                    if all([new_cell_valid, transition_valid]):
                         agent.speed_data['transition_action_on_cellexit'] = action
                         action_selected = True
 
@@ -260,10 +260,10 @@ class RailEnv(Environment):
                         # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
                         # try to keep moving forward!
                         if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT) and agent.moving:
-                            cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
+                            cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
                                 self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
 
-                            if all([new_cell_isValid, transition_isValid]):
+                            if all([new_cell_valid, transition_valid]):
                                 agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
                                 action_selected = True
 
@@ -271,17 +271,15 @@ class RailEnv(Environment):
                                 # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
                                 self.rewards_dict[i_agent] += invalid_action_penalty
                                 self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
-                                agent.moving = False
                                 self.rewards_dict[i_agent] += stop_penalty
-
+                                agent.moving = False
                                 continue
                         else:
                             # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
                             self.rewards_dict[i_agent] += invalid_action_penalty
                             self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
-                            agent.moving = False
                             self.rewards_dict[i_agent] += stop_penalty
-
+                            agent.moving = False
                             continue
 
             if agent.moving and (action_selected or agent.speed_data['position_fraction'] > 0.0):
@@ -293,10 +291,10 @@ class RailEnv(Environment):
 
                 # Now 'transition_action_on_cellexit' will be guaranteed to be valid; it was checked on entering
                 # the cell
-                cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
+                cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
                     self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent)
 
-                if all([new_cell_isValid, transition_isValid, cell_isFree]):
+                if all([new_cell_valid, transition_valid, cell_free]):
                     agent.position = new_position
                     agent.direction = new_direction
                     agent.speed_data['position_fraction'] = 0.0
@@ -316,14 +314,14 @@ class RailEnv(Environment):
     def _check_action_on_agent(self, action, agent):
         # compute number of possible transitions in the current
         # cell used to check for invalid actions
-        new_direction, transition_isValid = self.check_action(agent, action)
+        new_direction, transition_valid = self.check_action(agent, action)
         new_position = get_new_position(agent.position, new_direction)
 
         # Is it a legal move?
         # 1) transition allows the new_direction in the cell,
         # 2) the new cell is not empty (case 0),
         # 3) the cell is free, i.e., no agent is currently in that cell
-        new_cell_isValid = (
+        new_cell_valid = (
             np.array_equal(  # Check the new position is still in the grid
                 new_position,
                 np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
@@ -331,19 +329,19 @@ class RailEnv(Environment):
             self.rail.get_transitions(new_position) > 0)
 
         # If transition validity hasn't been checked yet.
-        if transition_isValid is None:
-            transition_isValid = self.rail.get_transition(
+        if transition_valid is None:
+            transition_valid = self.rail.get_transition(
                 (*agent.position, agent.direction),
                 new_direction)
 
         # Check the new position is not the same as any of the existing agent positions
         # (including itself, for simplicity, since it is moving)
-        cell_isFree = not np.any(
+        cell_free = not np.any(
             np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
-        return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
+        return cell_free, new_cell_valid, new_direction, new_position, transition_valid
 
     def check_action(self, agent, action):
-        transition_isValid = None
+        transition_valid = None
         possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
         num_transitions = np.count_nonzero(possible_transitions)
 
@@ -351,12 +349,12 @@ class RailEnv(Environment):
         if action == RailEnvActions.MOVE_LEFT:
             new_direction = agent.direction - 1
             if num_transitions <= 1:
-                transition_isValid = False
+                transition_valid = False
 
         elif action == RailEnvActions.MOVE_RIGHT:
             new_direction = agent.direction + 1
             if num_transitions <= 1:
-                transition_isValid = False
+                transition_valid = False
 
         new_direction %= 4
 
@@ -366,8 +364,8 @@ class RailEnv(Environment):
                 # new_direction will be the only valid transition
                 # - take only available transition
                 new_direction = np.argmax(possible_transitions)
-                transition_isValid = True
-        return new_direction, transition_isValid
+                transition_valid = True
+        return new_direction, transition_valid
 
     def _get_observations(self):
         self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py
new file mode 100644
index 0000000000000000000000000000000000000000..79f4bab164312f757ad584bd3708a7af3fb7a97e
--- /dev/null
+++ b/tests/test_distance_map.py
@@ -0,0 +1,56 @@
+import numpy as np
+
+from flatland.core.grid.grid4 import Grid4Transitions
+from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.generators import rail_from_GridTransitionMap_generator
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+
+
+def test_walker():
+    # _ _ _
+
+    cells = [int('0000000000000000', 2),  # empty cell - Case 0
+             int('1000000000100000', 2),  # Case 1 - straight
+             int('1001001000100000', 2),  # Case 2 - simple switch
+             int('1000010000100001', 2),  # Case 3 - diamond drossing
+             int('1001011000100001', 2),  # Case 4 - single slip switch
+             int('1100110000110011', 2),  # Case 5 - double slip switch
+             int('0101001000000010', 2),  # Case 6 - symmetrical switch
+             int('0010000000000000', 2)]  # Case 7 - dead end
+    transitions = Grid4Transitions([])
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+
+    rail_map = np.array(
+        [[dead_end_from_east] + [horizontal_straight] + [dead_end_from_west]], dtype=np.uint16)
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_GridTransitionMap_generator(rail),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2,
+                                                       predictor=ShortestPathPredictorForRailEnv(max_depth=10)),
+                  )
+    # reset to initialize agents_static
+    env.reset()
+
+    # set initial position and direction for testing...
+    env.agents_static[0].position = (0, 1)
+    env.agents_static[0].direction = 1
+    env.agents_static[0].target = (0, 0)
+
+    # reset to set agents from agents_static
+    env.reset(False, False)
+    obs_builder: TreeObsForRailEnv = env.obs_builder
+
+    print(obs_builder.distance_map[(0, *[0, 1], 1)])
+    assert obs_builder.distance_map[(0, *[0, 1], 1)] == 3
+    print(obs_builder.distance_map[(0, *[0, 2], 3)])
+    assert obs_builder.distance_map[(0, *[0, 2], 1)] == 2  # does not work yet, Erik's proposal.