diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py index f7cba9b6f9c94190e556f998b05c50fc5fa3c79b..7a3525d903d323487507f991e6bcb099a436b35e 100644 --- a/reinforcement_learning/dddqn_policy.py +++ b/reinforcement_learning/dddqn_policy.py @@ -55,11 +55,13 @@ class DDDQNPolicy(Policy): self.qnetwork_target = copy.deepcopy(self.qnetwork_local) self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=self.learning_rate) self.memory = ReplayBuffer(action_size, self.buffer_size, self.batch_size, self.device) - self.t_step = 0 self.loss = 0.0 + else: + self.memory = ReplayBuffer(action_size, 1, 1, self.device) + self.loss = 0.0 - def act(self, state, eps=0.): + def act(self, handle, state, eps=0.): state = torch.from_numpy(state).float().unsqueeze(0).to(self.device) self.qnetwork_local.eval() with torch.no_grad(): @@ -151,7 +153,7 @@ class DDDQNPolicy(Policy): self.memory.memory = pickle.load(f) def test(self): - self.act(np.array([[0] * self.state_size])) + self.act(0, np.array([[0] * self.state_size])) self._learn() def clone(self): diff --git a/reinforcement_learning/evaluate_agent.py b/reinforcement_learning/evaluate_agent.py index 64eb9433a9df00457d698e5873b31a16712de718..5488f81eae52753a071ef18142a5514579dd4c5c 100644 --- a/reinforcement_learning/evaluate_agent.py +++ b/reinforcement_learning/evaluate_agent.py @@ -26,7 +26,8 @@ from utils.observation_utils import normalize_observation from reinforcement_learning.dddqn_policy import DDDQNPolicy -def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size, state_size, seed, render, allow_skipping, allow_caching): +def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size, state_size, seed, render, + allow_skipping, allow_caching): # Evaluation is faster on CPU (except if you use a really huge policy) parameters = { 'use_gpu': False @@ -140,11 +141,12 @@ def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size, else: preproc_timer.start() - norm_obs = normalize_observation(obs[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius) + norm_obs = normalize_observation(obs[agent], tree_depth=observation_tree_depth, + observation_radius=observation_radius) preproc_timer.end() inference_timer.start() - action = policy.act(norm_obs, eps=0.0) + action = policy.act(agent, norm_obs, eps=0.0) inference_timer.end() action_dict.update({agent: action}) @@ -319,12 +321,15 @@ def evaluate_agents(file, n_evaluation_episodes, use_gpu, render, allow_skipping results = [] if render: - results.append(eval_policy(params, file, eval_per_thread, max_steps, action_size, state_size, 0, render, allow_skipping, allow_caching)) + results.append( + eval_policy(params, file, eval_per_thread, max_steps, action_size, state_size, 0, render, allow_skipping, + allow_caching)) else: with Pool() as p: results = p.starmap(eval_policy, - [(params, file, 1, max_steps, action_size, state_size, seed * nb_threads, render, allow_skipping, allow_caching) + [(params, file, 1, max_steps, action_size, state_size, seed * nb_threads, render, + allow_skipping, allow_caching) for seed in range(total_nb_eval)]) @@ -367,10 +372,12 @@ if __name__ == "__main__": parser.add_argument("--use_gpu", dest="use_gpu", help="use GPU if available", action='store_true') parser.add_argument("--render", help="render a single episode", action='store_true') - parser.add_argument("--allow_skipping", help="skips to the end of the episode if all agents are deadlocked", 
action='store_true') + parser.add_argument("--allow_skipping", help="skips to the end of the episode if all agents are deadlocked", + action='store_true') parser.add_argument("--allow_caching", help="caches the last observation-action pair", action='store_true') args = parser.parse_args() os.environ["OMP_NUM_THREADS"] = str(1) - evaluate_agents(file=args.file, n_evaluation_episodes=args.n_evaluation_episodes, use_gpu=args.use_gpu, render=args.render, + evaluate_agents(file=args.file, n_evaluation_episodes=args.n_evaluation_episodes, use_gpu=args.use_gpu, + render=args.render, allow_skipping=args.allow_skipping, allow_caching=args.allow_caching) diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py index dac599225c528dbd0a5f84c729cbb2568f942c51..e13584f491392974251b2134e29d9a5736dbac93 100755 --- a/reinforcement_learning/multi_agent_training.py +++ b/reinforcement_learning/multi_agent_training.py @@ -21,6 +21,8 @@ from torch.utils.tensorboard import SummaryWriter from reinforcement_learning.dddqn_policy import DDDQNPolicy from reinforcement_learning.ppo_agent import PPOAgent +from reinforcement_learning.ppo_deadlockavoidance_agent import MultiDecisionAgent +from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent from utils.deadlock_check import get_agent_positions, check_for_deadlock base_dir = Path(__file__).resolve().parent.parent @@ -174,6 +176,11 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params): policy = DDDQNPolicy(state_size, action_size, train_params) if True: policy = PPOAgent(state_size, action_size) + if False: + policy = DeadLockAvoidanceAgent(train_env, action_size) + if True: + policy = MultiDecisionAgent(train_env, state_size, action_size, policy) + # Load existing policy if train_params.load_policy is not "": policy.load(train_params.load_policy) @@ -226,7 +233,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params): train_env = create_rail_env(train_env_params, tree_observation) obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True) - policy.reset() + policy.reset(train_env) reset_timer.end() if train_params.render: @@ -261,8 +268,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params): agent = train_env.agents[agent_handle] if info['action_required'][agent_handle]: update_values[agent_handle] = True - action = policy.act(agent_obs[agent_handle], eps=eps_start) - + action = policy.act(agent_handle, agent_obs[agent_handle], eps=eps_start) action_count[action] += 1 actions_taken.append(action) else: @@ -288,7 +294,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params): all_rewards[agent_handle] = 0.0 if done[agent_handle] == False: if check_for_deadlock(agent_handle, train_env, agent_positions): - all_rewards[agent_handle] = -1.0 + all_rewards[agent_handle] = -5.0 else: pos = agent.position possible_transitions = train_env.rail.get_transitions(*pos, agent.direction) @@ -471,6 +477,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params): score = 0.0 obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True) + policy.reset(env) final_step = 0 policy.start_episode(train=False) @@ -484,7 +491,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params): action = 0 if info['action_required'][agent]: if tree_observation.check_is_observation_valid(agent_obs[agent]): - action = policy.act(agent_obs[agent], eps=0.0) + action = 
policy.act(agent, agent_obs[agent], eps=0.0) action_dict.update({agent: action}) policy.end_step(train=False) obs, all_rewards, done, info = env.step(action_dict) diff --git a/reinforcement_learning/multi_policy.py b/reinforcement_learning/multi_policy.py index 87763b1ab0e528627247f246ce734b7ddcbe55ab..0c2ae32144216bef5a95cc2ec43eb6ac027bfc7f 100644 --- a/reinforcement_learning/multi_policy.py +++ b/reinforcement_learning/multi_policy.py @@ -1,4 +1,5 @@ import numpy as np +from flatland.envs.rail_env import RailEnv from reinforcement_learning.policy import Policy from reinforcement_learning.ppo_agent import PPOAgent @@ -45,9 +46,9 @@ class MultiPolicy(Policy): self.loss = self.ppo_policy.loss return action_ppo - def reset(self): - self.ppo_policy.reset() - self.deadlock_avoidance_policy.reset() + def reset(self, env: RailEnv): + self.ppo_policy.reset(env) + self.deadlock_avoidance_policy.reset(env) def test(self): self.ppo_policy.test() diff --git a/reinforcement_learning/ordered_policy.py b/reinforcement_learning/ordered_policy.py index daf6639d33052eedc5b69481e84413edea552eee..2db171d2e1429a085488b02f9818ba75c57b2694 100644 --- a/reinforcement_learning/ordered_policy.py +++ b/reinforcement_learning/ordered_policy.py @@ -15,7 +15,7 @@ class OrderedPolicy(Policy): def __init__(self): self.action_size = 5 - def act(self, state, eps=0.): + def act(self, handle, state, eps=0.): _, distance, _ = split_tree_into_feature_groups(state, 1) distance = distance[1:] min_dist = min_gt(distance, 0) diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py index 45889da1780a9188d85cedcd6c40c899fe088c51..5b118aee15253d7dfb86c04925ea8a058abdbf2d 100644 --- a/reinforcement_learning/policy.py +++ b/reinforcement_learning/policy.py @@ -1,10 +1,11 @@ -import torch.nn as nn +from flatland.envs.rail_env import RailEnv + class Policy: def step(self, handle, state, action, reward, next_state, done): raise NotImplementedError - def act(self, state, eps=0.): + def act(self, handle, state, eps=0.): raise NotImplementedError def save(self, filename): @@ -13,16 +14,16 @@ class Policy: def load(self, filename): raise NotImplementedError - def start_step(self,train): + def start_step(self, train): pass - def end_step(self,train): + def end_step(self, train): pass - def start_episode(self,train): + def start_episode(self, train): pass - def end_episode(self,train): + def end_episode(self, train): pass def load_replay_buffer(self, filename): @@ -31,8 +32,8 @@ class Policy: def test(self): pass - def reset(self): + def reset(self, env: RailEnv): pass def clone(self): - return self \ No newline at end of file + return self diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py index ee179d19bf49fb7f3fe2d2ed47eb3b9123d257e4..a4e74ec884fd99288e45353a56d52cf31a27555f 100644 --- a/reinforcement_learning/ppo_agent.py +++ b/reinforcement_learning/ppo_agent.py @@ -9,7 +9,7 @@ from torch.distributions import Categorical # Hyperparameters from reinforcement_learning.policy import Policy -device = torch.device("cpu")#"cuda:0" if torch.cuda.is_available() else "cpu") +device = torch.device("cpu") # "cuda:0" if torch.cuda.is_available() else "cpu") print("device:", device) @@ -111,10 +111,10 @@ class PPOAgent(Policy): self.optimizer = optim.Adam(self.actor_critic_model.parameters(), lr=self.learning_rate) self.loss_function = nn.SmoothL1Loss() # nn.MSELoss() - def reset(self): + def reset(self, env): pass - def act(self, state, eps=None): + def act(self, handle, state, eps=None): # 
sample a action to take torch_state = torch.tensor(state, dtype=torch.float).to(device) dist = self.actor_critic_model.get_actor_dist(torch_state) @@ -148,10 +148,8 @@ class PPOAgent(Policy): reward_i = 1 else: done_list.insert(0, 0) - if reward_i < -1: - reward_i = -1 - else: - reward_i = 0 + reward_i = 0 + discounted_reward = reward_i + self.gamma * discounted_reward reward_list.insert(0, discounted_reward) state_next_list.insert(0, state_next_i) diff --git a/reinforcement_learning/ppo_deadlockavoidance_agent.py b/reinforcement_learning/ppo_deadlockavoidance_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..a3cf21638a4f04fba1b91e4cacbd668b62ce5996 --- /dev/null +++ b/reinforcement_learning/ppo_deadlockavoidance_agent.py @@ -0,0 +1,81 @@ +from flatland.envs.agent_utils import RailAgentStatus +from flatland.envs.rail_env import RailEnv, RailEnvActions + +from reinforcement_learning.policy import Policy +from utils.agent_can_choose_helper import AgentCanChooseHelper +from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent + + +class MultiDecisionAgent(Policy): + + def __init__(self, env: RailEnv, state_size, action_size, learning_agent): + self.env = env + self.state_size = state_size + self.action_size = action_size + self.learning_agent = learning_agent + self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, action_size, False) + self.agent_can_choose_helper = AgentCanChooseHelper() + self.memory = self.learning_agent.memory + self.loss = self.learning_agent.loss + + def step(self, handle, state, action, reward, next_state, done): + self.dead_lock_avoidance_agent.step(handle, state, action, reward, next_state, done) + self.learning_agent.step(handle, state, action, reward, next_state, done) + self.loss = self.learning_agent.loss + + def act(self, handle, state, eps=0.): + agent = self.env.agents[handle] + position = agent.position + if position is None: + position = agent.initial_position + direction = agent.direction + if agent.status < RailAgentStatus.DONE: + agents_on_switch, agents_near_to_switch, _, _ = \ + self.agent_can_choose_helper.check_agent_decision(position, direction) + if agents_on_switch or agents_near_to_switch: + return self.learning_agent.act(handle, state, eps) + else: + return self.dead_lock_avoidance_agent.act(handle, state, -1.0) + # Agent is still at target cell + return RailEnvActions.DO_NOTHING + + def save(self, filename): + self.dead_lock_avoidance_agent.save(filename) + self.learning_agent.save(filename) + + def load(self, filename): + self.dead_lock_avoidance_agent.load(filename) + self.learning_agent.load(filename) + + def start_step(self, train): + self.dead_lock_avoidance_agent.start_step(train) + self.learning_agent.start_step(train) + + def end_step(self, train): + self.dead_lock_avoidance_agent.end_step(train) + self.learning_agent.end_step(train) + + def start_episode(self, train): + self.dead_lock_avoidance_agent.start_episode(train) + self.learning_agent.start_episode(train) + + def end_episode(self, train): + self.dead_lock_avoidance_agent.end_episode(train) + self.learning_agent.end_episode(train) + + def load_replay_buffer(self, filename): + self.dead_lock_avoidance_agent.load_replay_buffer(filename) + self.learning_agent.load_replay_buffer(filename) + + def test(self): + self.dead_lock_avoidance_agent.test() + self.learning_agent.test() + + def reset(self, env: RailEnv): + self.env = env + self.agent_can_choose_helper.build_data(env) + self.dead_lock_avoidance_agent.reset(env) + 
self.learning_agent.reset(env) + + def clone(self): + return self diff --git a/reinforcement_learning/sequential_agent.py b/reinforcement_learning/sequential_agent.py index 3bb5a73cdc42f33a5e771eeaf530cf4af9742be8..e2055a69576454a0252a24e21408db0f04131da0 100644 --- a/reinforcement_learning/sequential_agent.py +++ b/reinforcement_learning/sequential_agent.py @@ -1,13 +1,13 @@ import sys -import numpy as np +from pathlib import Path +import numpy as np from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool -from pathlib import Path base_dir = Path(__file__).resolve().parent.parent sys.path.append(str(base_dir)) @@ -73,7 +73,7 @@ for trials in range(1, n_episodes + 1): if done[a]: acting_agent += 1 if a == acting_agent: - action = policy.act(obs[a]) + action = policy.act(a, obs[a]) else: action = 4 action_dict.update({a: action}) diff --git a/reinforcement_learning/sequential_agent_training.py b/reinforcement_learning/sequential_agent_training.py index ca19d1fcbbb4e3508a16b847d4b4cfcefc6aad98..d1ddd4348a462a9b7c17d6dae36c780acff1fd8b 100644 --- a/reinforcement_learning/sequential_agent_training.py +++ b/reinforcement_learning/sequential_agent_training.py @@ -1,13 +1,13 @@ import sys -import numpy as np +from pathlib import Path +import numpy as np from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool -from pathlib import Path base_dir = Path(__file__).resolve().parent.parent sys.path.append(str(base_dir)) @@ -66,7 +66,7 @@ for trials in range(1, n_episodes + 1): if done[a]: acting_agent += 1 if a == acting_agent: - action = policy.act(obs[a]) + action = policy.act(a, obs[a]) else: action = 4 action_dict.update({a: action}) diff --git a/reinforcement_learning/single_agent_training.py b/reinforcement_learning/single_agent_training.py index bfcc88656c8b37a8c09e72b51701d0750cf7f238..dda07a9db5b6da3c2185f65d259fa0a9cf549c50 100644 --- a/reinforcement_learning/single_agent_training.py +++ b/reinforcement_learning/single_agent_training.py @@ -123,7 +123,8 @@ def train_agent(n_episodes): # Build agent specific observations for agent in env.get_agent_handles(): if obs[agent]: - agent_obs[agent] = normalize_observation(obs[agent], observation_tree_depth, observation_radius=observation_radius) + agent_obs[agent] = normalize_observation(obs[agent], observation_tree_depth, + observation_radius=observation_radius) agent_prev_obs[agent] = agent_obs[agent].copy() # Run episode @@ -132,7 +133,7 @@ def train_agent(n_episodes): if info['action_required'][agent]: # If an action is required, we want to store the obs at that step as well as the action update_values = True - action = policy.act(agent_obs[agent], eps=eps_start) + action = policy.act(agent, agent_obs[agent], eps=eps_start) action_count[action] += 1 else: update_values = False @@ -154,7 +155,8 @@ def train_agent(n_episodes): agent_prev_action[agent] = action_dict[agent] if next_obs[agent]: - agent_obs[agent] = normalize_observation(next_obs[agent], 
observation_tree_depth, observation_radius=10) + agent_obs[agent] = normalize_observation(next_obs[agent], observation_tree_depth, + observation_radius=10) score += all_rewards[agent] @@ -179,15 +181,16 @@ def train_agent(n_episodes): else: end = " " - print('\rTraining {} agents on {}x{}\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( - env.get_num_agents(), - x_dim, y_dim, - episode_idx, - np.mean(scores_window), - 100 * np.mean(completion_window), - eps_start, - action_probs - ), end=end) + print( + '\rTraining {} agents on {}x{}\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( + env.get_num_agents(), + x_dim, y_dim, + episode_idx, + np.mean(scores_window), + 100 * np.mean(completion_window), + eps_start, + action_probs + ), end=end) # Plot overall training progress at the end plt.plot(scores) @@ -199,7 +202,8 @@ def train_agent(n_episodes): if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument("-n", "--n_episodes", dest="n_episodes", help="number of episodes to run", default=500, type=int) + parser.add_argument("-n", "--n_episodes", dest="n_episodes", help="number of episodes to run", default=500, + type=int) args = parser.parse_args() train_agent(args.n_episodes) diff --git a/run.py b/run.py index 8eb8f8109c498a2cf0f9bd27a9174a577a41e240..1b1d11fd79aa3aef4a87e5e043f319e0c507edf1 100644 --- a/run.py +++ b/run.py @@ -31,6 +31,7 @@ from flatland.evaluators.client import FlatlandRemoteClient from flatland.evaluators.client import TimeoutException from reinforcement_learning.ppo_agent import PPOAgent +from reinforcement_learning.ppo_deadlockavoidance_agent import MultiDecisionAgent from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent from utils.deadlock_check import check_if_all_blocked from utils.fast_tree_obs import FastTreeObs @@ -50,20 +51,23 @@ USE_FAST_TREEOBS = True USE_PPO_AGENT = True # Checkpoint to use (remember to push it!) 
-checkpoint = "./checkpoints/201124171810-7800.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10 +checkpoint = "./checkpoints/201124171810-7800.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10 # checkpoint = "./checkpoints/201126150143-5200.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10 # checkpoint = "./checkpoints/201126160144-2000.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10 -checkpoint = "./checkpoints/201207144650-20000.pth" # PPO: 14.45790721540786 -checkpoint = "./checkpoints/201211063511-6300.pth" # DDDQN: 16.948349308440857 -checkpoint = "./checkpoints/201211095604-12000.pth" # DDDQN: 17.3862941316504 -checkpoint = "./checkpoints/201211164554-9400.pth" # DDDQN: 16.09241366013537 -checkpoint = "./checkpoints/201213181400-6800.pth" # PPO: 13.944402986414723 +checkpoint = "./checkpoints/201207144650-20000.pth" # PPO: 14.45790721540786 +checkpoint = "./checkpoints/201211063511-6300.pth" # DDDQN: 16.948349308440857 +checkpoint = "./checkpoints/201211095604-12000.pth" # DDDQN: 17.3862941316504 +checkpoint = "./checkpoints/201211164554-9400.pth" # DDDQN: 16.09241366013537 +checkpoint = "./checkpoints/201213181400-6800.pth" # PPO: 13.944402986414723 +checkpoint = "./checkpoints/201214140158-5000.pth" # USE_MULTI_DECISION_AGENT with DDDQN: 13.944402986414723 +checkpoint = "./checkpoints/201214160604-3000.pth" # USE_MULTI_DECISION_AGENT with DDDQN: 13.944402986414723 EPSILON = 0.0 # Use last action cache USE_ACTION_CACHE = False USE_DEAD_LOCK_AVOIDANCE_AGENT = False # 21.54485505223213 +USE_MULTI_DECISION_AGENT = True # Observation parameters (must match training parameters!) observation_tree_depth = 2 @@ -106,10 +110,10 @@ action_size = 5 # Creates the policy. No GPU on evaluation server. if not USE_PPO_AGENT: - policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True) + trained_policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True) else: - policy = PPOAgent(state_size, action_size) -policy.load(checkpoint) + trained_policy = PPOAgent(state_size, action_size) +trained_policy.load(checkpoint) ##################################################################### # Main evaluation loop @@ -144,6 +148,11 @@ while True: tree_observation.set_env(local_env) tree_observation.reset() + + policy = trained_policy + if USE_MULTI_DECISION_AGENT: + policy = MultiDecisionAgent(local_env, state_size, action_size, trained_policy) + policy.reset(local_env) observation = tree_observation.get_many(list(range(nb_agents))) print("Evaluation {}: {} agents in {}x{}".format(evaluation_number, nb_agents, local_env.width, local_env.height)) @@ -199,7 +208,7 @@ while True: observation_tree_depth, observation_radius=observation_radius) - action = policy.act(normalized_observation, eps=EPSILON) + action = policy.act(agent_handle, normalized_observation, eps=EPSILON) action_dict[agent_handle] = action diff --git a/utils/agent_can_choose_helper.py b/utils/agent_can_choose_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..95636c6a13dc847e798ff283bfb4ccf8fc871a64 --- /dev/null +++ b/utils/agent_can_choose_helper.py @@ -0,0 +1,107 @@ +from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.agent_utils import RailAgentStatus +from flatland.envs.rail_env import fast_count_nonzero + + +class AgentCanChooseHelper: + def __init__(self): + pass + + def build_data(self, env): + self.env = env + if self.env is not None: + self.env.dev_obs_dict = {} + self.switches = {} + 
self.switches_neighbours = {} + if self.env is not None: + self.find_all_cell_where_agent_can_choose() + + def find_all_switches(self): + # Search the environment (rail grid) for all switch cells. A switch is a cell where more than one transition + exists; collect every direction in which the cell acts as a switch. + self.switches = {} + for h in range(self.env.height): + for w in range(self.env.width): + pos = (h, w) + for dir in range(4): + possible_transitions = self.env.rail.get_transitions(*pos, dir) + num_transitions = fast_count_nonzero(possible_transitions) + if num_transitions > 1: + if pos not in self.switches.keys(): + self.switches.update({pos: [dir]}) + else: + self.switches[pos].append(dir) + + def find_all_switch_neighbours(self): + # Collect all cells that are neighbours of a switch cell, i.e. cells from which the agent can reach + a switch in a single step. A switch is a cell where the agent has more than one transition. + self.switches_neighbours = {} + for h in range(self.env.height): + for w in range(self.env.width): + # look one step forward + for dir in range(4): + pos = (h, w) + possible_transitions = self.env.rail.get_transitions(*pos, dir) + for d in range(4): + if possible_transitions[d] == 1: + new_cell = get_new_position(pos, d) + if new_cell in self.switches.keys() and pos not in self.switches.keys(): + if pos not in self.switches_neighbours.keys(): + self.switches_neighbours.update({pos: [dir]}) + else: + self.switches_neighbours[pos].append(dir) + + def find_all_cell_where_agent_can_choose(self): + # prepare the memory - collect all cells where the agent can choose more than FORWARD/STOP. + self.find_all_switches() + self.find_all_switch_neighbours() + + def check_agent_decision(self, position, direction): + # Decide whether the agent is + # - on a switch + # - at a switch neighbour (near to switch). 
The switch must be a switch where the agent has more option than + # FORWARD/STOP + # - all switch : doesn't matter whether the agent has more options than FORWARD/STOP + # - all switch neightbors : doesn't matter the agent has more then one options (transistion) when he reach the + # switch + agents_on_switch = False + agents_on_switch_all = False + agents_near_to_switch = False + agents_near_to_switch_all = False + if position in self.switches.keys(): + agents_on_switch = direction in self.switches[position] + agents_on_switch_all = True + + if position in self.switches_neighbours.keys(): + new_cell = get_new_position(position, direction) + if new_cell in self.switches.keys(): + if not direction in self.switches[new_cell]: + agents_near_to_switch = direction in self.switches_neighbours[position] + else: + agents_near_to_switch = direction in self.switches_neighbours[position] + + agents_near_to_switch_all = direction in self.switches_neighbours[position] + + return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all + + def required_agent_decision(self): + agents_can_choose = {} + agents_on_switch = {} + agents_on_switch_all = {} + agents_near_to_switch = {} + agents_near_to_switch_all = {} + for a in range(self.env.get_num_agents()): + ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \ + self.check_agent_decision( + self.env.agents[a].position, + self.env.agents[a].direction) + agents_on_switch.update({a: ret_agents_on_switch}) + agents_on_switch_all.update({a: ret_agents_on_switch_all}) + ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART + agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)}) + + agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]}) + + agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)}) + + return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py index 07840db7b28505ea228db35e4e10f961c4015313..286718ea4a86ae75f9d32e72476f5e48f14a558f 100644 --- a/utils/dead_lock_avoidance_agent.py +++ b/utils/dead_lock_avoidance_agent.py @@ -66,10 +66,18 @@ class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker): self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1 +class DummyMemory: + def __init__(self): + self.memory = [] + + def __len__(self): + return 0 + + class DeadLockAvoidanceAgent(Policy): def __init__(self, env: RailEnv, action_size, show_debug_plot=False): self.env = env - self.memory = None + self.memory = DummyMemory() self.loss = 0 self.action_size = action_size self.agent_can_move = {} @@ -77,16 +85,16 @@ class DeadLockAvoidanceAgent(Policy): self.switches = {} self.show_debug_plot = show_debug_plot - def step(self, state, action, reward, next_state, done): + def step(self, handle, state, action, reward, next_state, done): pass - def act(self, state, eps=0.): + def act(self, handle, state, eps=0.): # Epsilon-greedy action selection if np.random.random() < eps: return np.random.choice(np.arange(self.action_size)) # agent = self.env.agents[state[0]] - check = self.agent_can_move.get(state[0], None) + check = self.agent_can_move.get(handle, None) if check is None: return RailEnvActions.STOP_MOVING return check[3] @@ -94,7 +102,8 @@ class DeadLockAvoidanceAgent(Policy): def 
get_agent_can_move_value(self, handle): return self.agent_can_move_value.get(handle, np.inf) - def reset(self): + def reset(self, env): + self.env = env self.agent_positions = None self.shortest_distance_walker = None self.switches = {} diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py index c45477db4dcb91ebbea36b729ebc6ff142bec000..8172703b51cf754981c03e487c08f51ea400aec0 100755 --- a/utils/fast_tree_obs.py +++ b/utils/fast_tree_obs.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Any import numpy as np from flatland.core.env_observation_builder import ObservationBuilder @@ -6,6 +6,7 @@ from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions +from utils.agent_can_choose_helper import AgentCanChooseHelper from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent from utils.deadlock_check import check_for_deadlock, get_agent_positions @@ -24,116 +25,14 @@ Author: Adrian Egli (adrian.egli@gmail.com) class FastTreeObs(ObservationBuilder): - def __init__(self, max_depth): + def __init__(self, max_depth: Any): self.max_depth = max_depth self.observation_dim = 41 - - def build_data(self): - if self.env is not None: - self.env.dev_obs_dict = {} - self.switches = {} - self.switches_neighbours = {} - self.debug_render_list = [] - self.debug_render_path_list = [] - if self.env is not None: - self.find_all_cell_where_agent_can_choose() - self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, 5, False) - else: - self.dead_lock_avoidance_agent = None - - def find_all_switches(self): - # Search the environment (rail grid) for all switch cells. A switch is a cell where more than one tranisation - # exists and collect all direction where the switch is a switch. - self.switches = {} - for h in range(self.env.height): - for w in range(self.env.width): - pos = (h, w) - for dir in range(4): - possible_transitions = self.env.rail.get_transitions(*pos, dir) - num_transitions = fast_count_nonzero(possible_transitions) - if num_transitions > 1: - if pos not in self.switches.keys(): - self.switches.update({pos: [dir]}) - else: - self.switches[pos].append(dir) - - def find_all_switch_neighbours(self): - # Collect all cells where is a neighbour to a switch cell. All cells are neighbour where the agent can make - # just one step and he stands on a switch. A switch is a cell where the agents has more than one transition. - self.switches_neighbours = {} - for h in range(self.env.height): - for w in range(self.env.width): - # look one step forward - for dir in range(4): - pos = (h, w) - possible_transitions = self.env.rail.get_transitions(*pos, dir) - for d in range(4): - if possible_transitions[d] == 1: - new_cell = get_new_position(pos, d) - if new_cell in self.switches.keys() and pos not in self.switches.keys(): - if pos not in self.switches_neighbours.keys(): - self.switches_neighbours.update({pos: [dir]}) - else: - self.switches_neighbours[pos].append(dir) - - def find_all_cell_where_agent_can_choose(self): - # prepare the memory - collect all cells where the agent can choose more than FORWARD/STOP. - self.find_all_switches() - self.find_all_switch_neighbours() - - def check_agent_decision(self, position, direction): - # Decide whether the agent is - # - on a switch - # - at a switch neighbour (near to switch). 
The switch must be a switch where the agent has more option than - # FORWARD/STOP - # - all switch : doesn't matter whether the agent has more options than FORWARD/STOP - # - all switch neightbors : doesn't matter the agent has more then one options (transistion) when he reach the - # switch - agents_on_switch = False - agents_on_switch_all = False - agents_near_to_switch = False - agents_near_to_switch_all = False - if position in self.switches.keys(): - agents_on_switch = direction in self.switches[position] - agents_on_switch_all = True - - if position in self.switches_neighbours.keys(): - new_cell = get_new_position(position, direction) - if new_cell in self.switches.keys(): - if not direction in self.switches[new_cell]: - agents_near_to_switch = direction in self.switches_neighbours[position] - else: - agents_near_to_switch = direction in self.switches_neighbours[position] - - agents_near_to_switch_all = direction in self.switches_neighbours[position] - - return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all - - def required_agent_decision(self): - agents_can_choose = {} - agents_on_switch = {} - agents_on_switch_all = {} - agents_near_to_switch = {} - agents_near_to_switch_all = {} - for a in range(self.env.get_num_agents()): - ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \ - self.check_agent_decision( - self.env.agents[a].position, - self.env.agents[a].direction) - agents_on_switch.update({a: ret_agents_on_switch}) - agents_on_switch_all.update({a: ret_agents_on_switch_all}) - ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART - agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)}) - - agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]}) - - agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)}) - - return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all + self.agent_can_choose_helper = None def debug_render(self, env_renderer): agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \ - self.required_agent_decision() + self.agent_can_choose_helper.required_agent_decision() self.env.dev_obs_dict = {} for a in range(max(3, self.env.get_num_agents())): self.env.dev_obs_dict.update({a: []}) @@ -156,13 +55,20 @@ class FastTreeObs(ObservationBuilder): env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000") self.env.dev_obs_dict[0] = self.debug_render_list - self.env.dev_obs_dict[1] = self.switches.keys() - self.env.dev_obs_dict[2] = self.switches_neighbours.keys() + self.env.dev_obs_dict[1] = self.agent_can_choose_helper.switches.keys() + self.env.dev_obs_dict[2] = self.agent_can_choose_helper.switches_neighbours.keys() self.env.dev_obs_dict[3] = self.debug_render_path_list def reset(self): - self.build_data() - return + if self.agent_can_choose_helper is None: + self.agent_can_choose_helper = AgentCanChooseHelper() + self.agent_can_choose_helper.build_data(self.env) + self.debug_render_list = [] + self.debug_render_path_list = [] + if self.env is not None: + self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, 5, False) + else: + self.dead_lock_avoidance_agent = None def _explore(self, handle, new_position, new_direction, distance_map, depth=0): has_opp_agent = 0 @@ -201,7 +107,7 @@ class FastTreeObs(ObservationBuilder): # agent_near_to_switch == 
TRUE -> One cell before the switch, where the agent can decide # agents_on_switch, agents_near_to_switch, _, _ = \ - self.check_agent_decision(new_position, new_direction) + self.agent_can_choose_helper.check_agent_decision(new_position, new_direction) if agents_near_to_switch: # The exploration was walking on a path where the agent can not decide @@ -250,7 +156,7 @@ class FastTreeObs(ObservationBuilder): self.dead_lock_avoidance_agent.end_step(train=False) return observations - def get(self, handle): + def get(self, handle: int = 0): # all values are [0,1] # observation[0] : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path # observation[1] : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path @@ -340,14 +246,14 @@ class FastTreeObs(ObservationBuilder): agents_near_to_switch, \ agents_near_to_switch_all, \ agents_on_switch_all = \ - self.check_agent_decision(agent_virtual_position, agent.direction) + self.agent_can_choose_helper.check_agent_decision(agent_virtual_position, agent.direction) observation[7] = int(agents_on_switch) observation[8] = int(agents_on_switch_all) observation[9] = int(agents_near_to_switch) observation[10] = int(agents_near_to_switch_all) - action = self.dead_lock_avoidance_agent.act([handle], 0.0) + action = self.dead_lock_avoidance_agent.act(handle, None, 0.0) observation[35] = int(action == RailEnvActions.STOP_MOVING) observation[40] = int(check_for_deadlock(handle, self.env, self.agent_positions)) diff --git a/utils/shortest_path_walker_heuristic_agent.py b/utils/shortest_path_walker_heuristic_agent.py index eaa71e91a416b0c899519e690c4e29ad8147a48d..d2cbab04f407edeae3fba5030a0b7b3309560cfc 100644 --- a/utils/shortest_path_walker_heuristic_agent.py +++ b/utils/shortest_path_walker_heuristic_agent.py @@ -8,7 +8,7 @@ class ShortestPathWalkerHeuristicPolicy(Policy): def step(self, state, action, reward, next_state, done): pass - def act(self, node, eps=0.): + def act(self, handle, node, eps=0.): left_node = node.childs.get('L') forward_node = node.childs.get('F')
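Usage note (illustrative sketch, not part of the patch): the snippet below shows how the handle-aware policy interface introduced above (act(handle, state, eps), reset(env)) and the new MultiDecisionAgent wrapper fit together at evaluation time, mirroring the wiring in run.py. It assumes an already constructed Flatland RailEnv named env that uses the FastTreeObs observation builder (so obs[handle] is the flat 41-dimensional observation vector) and that the repository's reinforcement_learning package is on the import path.

# Illustrative sketch only -- not part of the diff above. Assumes an existing
# Flatland RailEnv `env` built with the FastTreeObs observation builder.
from argparse import Namespace

from reinforcement_learning.dddqn_policy import DDDQNPolicy
from reinforcement_learning.ppo_deadlockavoidance_agent import MultiDecisionAgent

state_size = 41   # FastTreeObs.observation_dim used throughout this repo
action_size = 5   # Flatland rail actions

# Underlying learning policy (evaluation mode, CPU), created as in run.py.
learning_policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
# learning_policy.load("./checkpoints/201214160604-3000.pth")  # optionally load trained weights

obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)

# Wrap the learned policy: on or near a switch the learning agent decides,
# everywhere else the embedded DeadLockAvoidanceAgent drives the agent.
policy = MultiDecisionAgent(env, state_size, action_size, learning_policy)
policy.reset(env)  # reset() now receives the environment

policy.start_episode(train=False)
policy.start_step(train=False)
action_dict = {}
for handle in env.get_agent_handles():
    # act() now takes the agent handle as its first argument
    action_dict[handle] = policy.act(handle, obs[handle], eps=0.0)
policy.end_step(train=False)
obs, all_rewards, done, info = env.step(action_dict)
policy.end_episode(train=False)

The delegation itself is the point of the patch: MultiDecisionAgent only consults the learned policy where a decision actually matters (AgentCanChooseHelper.check_agent_decision reports the agent on or next to a switch) and otherwise falls back to the deadlock-avoidance walker, which is why both run.py and multi_agent_training.py now pass the agent handle and the environment into the policy.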