diff --git a/run.py b/run.py
index 7ad08cd45122f68de9da711022269f4255727a69..fec8bc01cff56b1733a016666eb703ae560353c0 100644
--- a/run.py
+++ b/run.py
@@ -29,6 +29,7 @@ checkpoint = "./checkpoints/201103160541-1800.pth"
 
 # Use last action cache
 USE_ACTION_CACHE = True
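+# If True, use the DeadLockAvoidanceAgent as the acting policy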
+USE_DEAD_LOCK_AVOIDANCE_AGENT = False
 
 # Observation parameters (must match training parameters!)
 observation_tree_depth = 2
@@ -82,8 +83,6 @@ while True:
     nb_agents = len(local_env.agents)
     max_nb_steps = local_env._max_episode_steps
 
-    policy = DeadLockAvoidanceAgent(local_env)
-
     tree_observation.set_env(local_env)
     tree_observation.reset()
     observation = tree_observation.get_many(list(range(nb_agents)))
@@ -105,6 +104,9 @@ while True:
     agent_last_action = {}
     nb_hit = 0
 
+    if USE_DEAD_LOCK_AVOIDANCE_AGENT:
+        policy = DeadLockAvoidanceAgent(local_env)
+
     while True:
         try:
             #####################################################################
@@ -118,7 +120,14 @@ while True:
                 time_start = time.time()
                 action_dict = {}
                 policy.start_step()
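+                # The DeadLockAvoidanceAgent does not use the tree observation; give it a
+                # minimal (agent handle, step) row per agent instead.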
+                if USE_DEAD_LOCK_AVOIDANCE_AGENT:
+                    observation = np.zeros((local_env.get_num_agents(), 2))
                 for agent in range(nb_agents):
+
+                    if USE_DEAD_LOCK_AVOIDANCE_AGENT:
+                        observation[agent][0] = agent
+                        observation[agent][1] = steps
+
                     if info['action_required'][agent]:
                         if agent in agent_last_obs and np.all(agent_last_obs[agent] == observation[agent]):
                             # cache hit
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index c7c6d8a5c171eb8cc208520bf6dc8f5d6cfc2845..700600c337882271eebe519233b473449be3b1ab 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -7,7 +7,7 @@ from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
 
 from reinforcement_learning.policy import Policy
-from utils.shortest_Distance_walker import ShortestDistanceWalker
+from utils.shortest_distance_walker import ShortestDistanceWalker
 
 
 class DeadlockAvoidanceObservation(DummyObservationBuilder):
diff --git a/utils/shortest_distance_walker.py b/utils/shortest_distance_walker.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b8121f46e681ebc37ea3d1afb6b4023d33f2e14
--- /dev/null
+++ b/utils/shortest_distance_walker.py
@@ -0,0 +1,87 @@
+import numpy as np
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_env import fast_count_nonzero, fast_argmax
+
+
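+# Helper that walks an agent along its shortest path to its target using the environment's
+# distance map. Subclasses hook into callback()/callback_one_step() to observe each visited cell.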
+class ShortestDistanceWalker:
+    def __init__(self, env: RailEnv):
+        self.env = env
+
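+    # Advance one cell along the shortest path. With a single outgoing transition the agent
+    # simply moves forward; at a switch the distance map is queried and get_action() picks the
+    # branch to follow. Returns (position, direction, distance, action, possible_transitions).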
+    def walk(self, handle, position, direction):
+        possible_transitions = self.env.rail.get_transitions(*position, direction)
+        num_transitions = fast_count_nonzero(possible_transitions)
+        if num_transitions == 1:
+            new_direction = fast_argmax(possible_transitions)
+            new_position = get_new_position(position, new_direction)
+
+            dist = self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction]
+            return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD, possible_transitions
+        else:
+            min_distances = []
+            positions = []
+            directions = []
+            for new_direction in [(direction + i) % 4 for i in range(-1, 2)]:
+                if possible_transitions[new_direction]:
+                    new_position = get_new_position(position, new_direction)
+                    min_distances.append(
+                        self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction])
+                    positions.append(new_position)
+                    directions.append(new_direction)
+                else:
+                    min_distances.append(np.inf)
+                    positions.append(None)
+                    directions.append(None)
+
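+        # get_action returns the branch index (0 = turn left, 1 = forward, 2 = turn right);
+        # adding 1 maps it onto RailEnvActions (MOVE_LEFT=1, MOVE_FORWARD=2, MOVE_RIGHT=3).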
+        a = self.get_action(handle, min_distances)
+        return positions[a], directions[a], min_distances[a], a + 1, possible_transitions
+
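+    # Default branch choice: greedily take the direction with the smallest remaining distance.
+    # Subclasses may override this to change how branches are chosen.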
+    def get_action(self, handle, min_distances):
+        return np.argmin(min_distances)
+
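+    # Per-cell hook used by walk_to_target; does nothing by default and is meant to be overridden.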
+    def callback(self, handle, agent, position, direction, action, possible_transitions):
+        pass
+
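+    # Return the agent's current position and direction; agents that have no position on the
+    # grid yet fall back to their initial position.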
+    def get_agent_position_and_direction(self, handle):
+        agent = self.env.agents[handle]
+        if agent.position is not None:
+            position = agent.position
+        else:
+            position = agent.initial_position
+        direction = agent.direction
+        return position, direction
+
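+    # Follow the shortest path from the given position/direction (defaulting to the agent's
+    # current state) towards its target, calling callback() for each cell reached, for at
+    # most max_step steps.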
+    def walk_to_target(self, handle, position=None, direction=None, max_step=500):
+        if position is None and direction is None:
+            position, direction = self.get_agent_position_and_direction(handle)
+        elif position is None:
+            position, _ = self.get_agent_position_and_direction(handle)
+        elif direction is None:
+            _, direction = self.get_agent_position_and_direction(handle)
+
+        agent = self.env.agents[handle]
+        step = 0
+        while (position != agent.target) and (step < max_step):
+            position, direction, dist, action, possible_transitions = self.walk(handle, position, direction)
+            if position is None:
+                break
+            self.callback(handle, agent, position, direction, action, possible_transitions)
+            step += 1
+
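+    # Hook used by walk_one_step for the newly reached cell; meant to be overridden.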
+    def callback_one_step(self, handle, agent, position, direction, action, possible_transitions):
+        pass
+
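+    # Advance the agent by a single cell along its shortest path and return the new position,
+    # direction, action and possible transitions.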
+    def walk_one_step(self, handle):
+        agent = self.env.agents[handle]
+        if agent.position is not None:
+            position = agent.position
+        else:
+            position = agent.initial_position
+        direction = agent.direction
+        possible_transitions = (0, 1, 0, 0)
+        # Defaults for an agent that is already on its target cell; without them the final
+        # return below would reference unbound variables.
+        new_position = position
+        new_direction = direction
+        action = RailEnvActions.STOP_MOVING
+        if (position != agent.target):
+            new_position, new_direction, dist, action, possible_transitions = self.walk(handle, position, direction)
+            if new_position is None:
+                return position, direction, RailEnvActions.STOP_MOVING, possible_transitions
+            self.callback_one_step(handle, agent, new_position, new_direction, action, possible_transitions)
+        return new_position, new_direction, action, possible_transitions