diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 7a80c783d5d12d38aacc43952838d07a434d7d1d..c47b484c9d163482af343b5a49c5e918e36c0d2a 100644
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -194,8 +194,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         policy = PPOAgent(state_size, action_size, n_agents)
     if False:
         policy = MultiPolicy(state_size, action_size, n_agents, train_env)
-    if True:
-        policy = DeadLockAvoidanceAgent(train_env,state_size, action_size)
+    if False:
+        policy = DeadLockAvoidanceAgent(train_env, state_size, action_size)
 
     # Load existing policy
     if train_params.load_policy is not None:
@@ -253,6 +253,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         reset_timer.start()
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
         policy.reset()
+        deadLockAvoidanceAgent = DeadLockAvoidanceAgent(train_env, state_size, action_size)
         reset_timer.end()
 
         if train_params.render:
@@ -273,20 +274,25 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         for step in range(max_steps - 1):
             inference_timer.start()
             policy.start_step()
+            deadLockAvoidanceAgent.start_step()
             for agent in train_env.get_agent_handles():
-                if info['action_required'][agent]:
-                    update_values[agent] = True
-                    action = policy.act(agent,agent_obs[agent], eps=eps_start)
-
-                    action_count[action] += 1
-                    actions_taken.append(action)
-                else:
-                    # An action is not required if the train hasn't joined the railway network,
-                    # if it already reached its target, or if is currently malfunctioning.
-                    update_values[agent] = False
-                    action = 0
+                action = deadLockAvoidanceAgent.act(agent, None, 0.0)
+                update_values[agent] = False
+                if action != RailEnvActions.STOP_MOVING:
+                    if info['action_required'][agent]:
+                        update_values[agent] = True
+                        action = policy.act(agent, agent_obs[agent], eps=eps_start)
+                        action_count[action] += 1
+                        actions_taken.append(action)
+                    else:
+                        # An action is not required if the train hasn't joined the railway network,
+                        # if it already reached its target, or if is currently malfunctioning.
+                        action = 0
+
                 action_dict.update({agent: action})
             policy.end_step()
+            deadLockAvoidanceAgent.end_step()
+
             inference_timer.end()
 
             # Environment step
@@ -458,22 +464,26 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
         score = 0.0
 
         obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
+        deadLockAvoidanceAgent = DeadLockAvoidanceAgent(env, None, None)
 
         final_step = 0
 
         for step in range(max_steps - 1):
+            deadLockAvoidanceAgent.start_step()
             for agent in env.get_agent_handles():
                 if tree_observation.check_is_observation_valid(agent_obs[agent]):
                     agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], tree_depth=tree_depth,
                                                                                    observation_radius=observation_radius)
 
-                action = 0
-                if info['action_required'][agent]:
-                    if tree_observation.check_is_observation_valid(agent_obs[agent]):
-                        action = policy.act(agent,agent_obs[agent], eps=0.0)
+                action = deadLockAvoidanceAgent.act(agent, None, 0)
+                if action != RailEnvActions.STOP_MOVING:
+                    if info['action_required'][agent]:
+                        if tree_observation.check_is_observation_valid(agent_obs[agent]):
+                            action = policy.act(agent, agent_obs[agent], eps=0.0)
                 action_dict.update({agent: action})
 
             obs, all_rewards, done, info = env.step(action_dict)
+            deadLockAvoidanceAgent.end_step()
 
             for agent in env.get_agent_handles():
                 score += all_rewards[agent]
@@ -505,7 +515,7 @@ if __name__ == "__main__":
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=200000, type=int)
     parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
                         type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=200, type=int)
     parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.0001, type=float)
@@ -525,9 +535,9 @@ if __name__ == "__main__":
     parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
     parser.add_argument("--use_extra_observation", help="extra observation", default=True, type=bool)
-    parser.add_argument("--max_depth", help="max depth", default=-1, type=int)
+    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
     parser.add_argument("--close_following", help="enable close following feature", default=True, type=bool)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int)
     parser.add_argument("--render", help="render 1 episode in 100", default=False, type=bool)
 
     training_params = parser.parse_args()
diff --git a/reinforcement_learning/multi_policy.py b/reinforcement_learning/multi_policy.py
index 2ce6d0f218870063fe546bc4b2674f355867a7e5..765bcf599f681f8b7a3dca311f223c5eed85e42d 100644
--- a/reinforcement_learning/multi_policy.py
+++ b/reinforcement_learning/multi_policy.py
@@ -13,7 +13,6 @@ class MultiPolicy(Policy):
         self.action_size = action_size
         self.memory = []
         self.loss = 0
-        self.dead_lock_avoidance_policy = DeadLockAvoidanceAgent(env, state_size, action_size)
         self.extra_policy = ExtraPolicy(state_size, action_size)
         self.ppo_policy = PPOAgent(state_size + action_size, action_size, n_agents, env)
 
@@ -40,9 +39,6 @@ class MultiPolicy(Policy):
         self.ppo_policy.step(handle, extended_state, action, reward, extended_next_state, done)
 
     def act(self, handle, state, eps=0.):
-        dead_lock_avoidance_action = self.dead_lock_avoidance_policy.act(handle, state, 0.0)
-        if dead_lock_avoidance_action == RailEnvActions.STOP_MOVING:
-            return RailEnvActions.STOP_MOVING
         action_extra_state = self.extra_policy.act(handle, state, 0.0)
         extended_state = np.copy(state)
         for action_itr in np.arange(self.action_size):
@@ -60,11 +56,9 @@ class MultiPolicy(Policy):
         self.extra_policy.test()
 
     def start_step(self):
-        self.dead_lock_avoidance_policy.start_step()
         self.extra_policy.start_step()
         self.ppo_policy.start_step()
 
     def end_step(self):
-        self.dead_lock_avoidance_policy.end_step()
         self.extra_policy.end_step()
         self.ppo_policy.end_step()
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index 382959aebf06df06faa50364ce569ee4e21f2cbb..bb9dc3d47b8f3a0d950f3acd377359b72350166d 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -1,14 +1,14 @@
 import matplotlib.pyplot as plt
 import numpy as np
 from flatland.envs.agent_utils import RailAgentStatus
-from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
 
 from reinforcement_learning.policy import Policy
 from utils.shortest_Distance_walker import ShortestDistanceWalker
 
 
 class MyWalker(ShortestDistanceWalker):
-    def __init__(self, env: RailEnv, agent_positions):
+    def __init__(self, env: RailEnv, agent_positions, switches):
         super().__init__(env)
         self.shortest_distance_agent_map = np.zeros((self.env.get_num_agents(),
                                                      self.env.height,
@@ -22,25 +22,38 @@ class MyWalker(ShortestDistanceWalker):
 
         self.agent_positions = agent_positions
 
-        self.agent_map = {}
+        self.opp_agent_map = {}
+        self.same_agent_map = {}
+        self.switches = switches
 
     def get_action(self, handle, min_distances):
+        if min_distances[0] != np.inf:
+            m = min(min_distances)
+            if min_distances[0] < m + 5:
+                return 0
         return np.argmin(min_distances)
 
     def getData(self):
         return self.shortest_distance_agent_map, self.full_shortest_distance_agent_map
 
-    def callback(self, handle, agent, position, direction, action):
+    def callback(self, handle, agent, position, direction, action, possible_transitions):
         opp_a = self.agent_positions[position]
         if opp_a != -1 and opp_a != handle:
             if self.env.agents[opp_a].direction != direction:
-                d = self.agent_map.get(handle, [])
+                d = self.opp_agent_map.get(handle, [])
                 if opp_a not in d:
                     d.append(opp_a)
-                self.agent_map.update({handle: d})
-        d = self.agent_map.get(handle, [])
-        if len(d) == 0:
-            self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1
+                self.opp_agent_map.update({handle: d})
+            else:
+                if len(self.opp_agent_map.get(handle, [])) == 0:
+                    d = self.same_agent_map.get(handle, [])
+                    if opp_a not in d:
+                        d.append(opp_a)
+                    self.same_agent_map.update({handle: d})
+
+        if len(self.opp_agent_map.get(handle, [])) == 0:
+            if self.switches.get(position, None) is None:
+                self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1
         self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1
 
 
@@ -52,6 +65,7 @@ class DeadLockAvoidanceAgent(Policy):
         self.memory = []
         self.loss = 0
         self.agent_can_move = {}
+        self.switches = {}
 
     def step(self, handle, state, action, reward, next_state, done):
         pass
@@ -61,11 +75,21 @@ class DeadLockAvoidanceAgent(Policy):
         check = self.agent_can_move.get(handle, None)
         if check is None:
             return RailEnvActions.STOP_MOVING
-
         return check[3]
 
     def reset(self):
-        pass
+        self.switches = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                pos = (h, w)
+                for dir in range(4):
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    num_transitions = fast_count_nonzero(possible_transitions)
+                    if num_transitions > 1:
+                        if pos not in self.switches.keys():
+                            self.switches.update({pos: [dir]})
+                        else:
+                            self.switches[pos].append(dir)
 
     def start_step(self):
         self.shortest_distance_mapper()
@@ -86,7 +110,7 @@ class DeadLockAvoidanceAgent(Policy):
                 if agent.position is not None:
                     agent_positions[agent.position] = handle
 
-        my_walker = MyWalker(self.env, agent_positions)
+        my_walker = MyWalker(self.env, agent_positions, self.switches)
         for handle in range(self.env.get_num_agents()):
             agent = self.env.agents[handle]
             if agent.status <= RailAgentStatus.ACTIVE:
@@ -96,14 +120,15 @@ class DeadLockAvoidanceAgent(Policy):
         self.agent_can_move = {}
         agent_positions_map = (agent_positions > -1).astype(int)
         for handle in range(self.env.get_num_agents()):
-            opp_agents = my_walker.agent_map.get(handle, [])
+            opp_agents = my_walker.opp_agent_map.get(handle, [])
+            same_agents = my_walker.same_agent_map.get(handle, [])
             me = shortest_distance_agent_map[handle]
             delta = me
             next_step_ok = True
-            next_position, next_direction, action = my_walker.walk_one_step(handle)
+            next_position, next_direction, action, possible_transitions = my_walker.walk_one_step(handle)
             for opp_a in opp_agents:
                 opp = full_shortest_distance_agent_map[opp_a]
-                delta = (delta - opp - agent_positions_map > 0).astype(int)
+                delta = ((delta - opp - agent_positions_map) > 0).astype(int)
                 if (np.sum(delta) < 3):
                     next_step_ok = False
 
diff --git a/utils/extra.py b/utils/extra.py
index 1145521ad3a3a087a38a4fb514fa4b60b4a4ffc0..83263c7398b5bdd2a60cdbef04a34da6535d23c4 100644
--- a/utils/extra.py
+++ b/utils/extra.py
@@ -84,7 +84,7 @@ class Extra(ObservationBuilder):
             def getData(self):
                 return self.shortest_distance_agent_counter, self.shortest_distance_agent_direction_counter
 
-            def callback(self, handle, agent, position, direction, action):
+            def callback(self, handle, agent, position, direction, action, possible_transitions):
                 self.shortest_distance_agent_counter[position] += 1
                 self.shortest_distance_agent_direction_counter[(position[0], position[1], direction)] += 1
 
diff --git a/utils/shortest_Distance_walker.py b/utils/shortest_Distance_walker.py
index bd1d5b33f3a58c15b75106efcb21ef11f7dd22cc..ad754b154dd28d8d3f7e0e376a2fd8eda8c89079 100644
--- a/utils/shortest_Distance_walker.py
+++ b/utils/shortest_Distance_walker.py
@@ -16,7 +16,7 @@ class ShortestDistanceWalker:
             new_position = get_new_position(position, new_direction)
 
             dist = self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction]
-            return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD
+            return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD, possible_transitions
         else:
             min_distances = []
             positions = []
@@ -34,28 +34,31 @@ class ShortestDistanceWalker:
                     directions.append(None)
 
         a = self.get_action(handle, min_distances)
-        return positions[a], directions[a], min_distances[a], a + 1
+        return positions[a], directions[a], min_distances[a], a + 1, possible_transitions
 
     def get_action(self, handle, min_distances):
         return np.argmin(min_distances)
 
-    def callback(self, handle, agent, position, direction, action):
+    def callback(self, handle, agent, position, direction, action, possible_transitions):
         pass
 
-    def walk_to_target(self, handle):
+    def walk_to_target(self, handle, max_step=500):
         agent = self.env.agents[handle]
         if agent.position is not None:
             position = agent.position
         else:
             position = agent.initial_position
         direction = agent.direction
-        while (position != agent.target):
-            position, direction, dist, action = self.walk(handle, position, direction)
+
+        step = 0
+        while (position != agent.target) and (step < max_step):
+            position, direction, dist, action, possible_transitions = self.walk(handle, position, direction)
             if position is None:
                 break
-            self.callback(handle, agent, position, direction, action)
+            self.callback(handle, agent, position, direction, action, possible_transitions)
+            step += 1
 
-    def callback_one_step(self, handle, agent, position, direction, action):
+    def callback_one_step(self, handle, agent, position, direction, action, possible_transitions):
         pass
 
     def walk_one_step(self, handle):
@@ -65,9 +68,10 @@ class ShortestDistanceWalker:
         else:
             position = agent.initial_position
         direction = agent.direction
+        possible_transitions = (0, 1, 0, 0)
         if (position != agent.target):
-            new_position, new_direction, dist, action = self.walk(handle, position, direction)
+            new_position, new_direction, dist, action, possible_transitions = self.walk(handle, position, direction)
             if new_position is None:
-                return position, direction, RailEnvActions.STOP_MOVING
-            self.callback_one_step(handle, agent, new_position, new_direction, action)
-        return new_position, new_direction, action
+                return position, direction, RailEnvActions.STOP_MOVING, possible_transitions
+            self.callback_one_step(handle, agent, new_position, new_direction, action, possible_transitions)
+        return new_position, new_direction, action, possible_transitions