diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index b43cf9480230d9e3ad8eded82a77ffaf897fcc97..1d99a9b80d51d8ff34991efbeb802af2a373a292 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -23,7 +23,7 @@ sys.path.append(str(base_dir))
 
 from utils.timer import Timer
 from utils.observation_utils import normalize_observation
-from utils.fast_tree_obs import FastTreeObs
+from utils.fast_tree_obs import FastTreeObs, fast_tree_obs_check_agent_deadlock
 from reinforcement_learning.dddqn_policy import DDDQNPolicy
 
 try:
@@ -156,12 +156,6 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     # The action space of flatland is 5 discrete actions
     action_size = 5
 
-    # Max number of steps per episode
-    # This is the official formula used during evaluations
-    # See details in flatland.envs.schedule_generators.sparse_schedule_generator
-    # max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities)))
-    max_steps = train_env._max_episode_steps
-
     action_count = [0] * action_size
     action_dict = dict()
     agent_obs = [None] * n_agents
@@ -229,6 +223,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
         # Reset environment
         reset_timer.start()
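+        # Cycle the number of agents (1..n_agents) across episodes and rebuild the rail env to match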
+        train_env_params.n_agents = episode_idx % n_agents + 1
+        train_env = create_rail_env(train_env_params, tree_observation)
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
         reset_timer.end()
 
@@ -246,6 +242,12 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                                                                                observation_radius=observation_radius)
                 agent_prev_obs[agent] = agent_obs[agent].copy()
 
+        # Max number of steps per episode
+        # This is the official formula used during evaluations
+        # See details in flatland.envs.schedule_generators.sparse_schedule_generator
+        # max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities)))
+        max_steps = train_env._max_episode_steps
+
         # Run episode
         for step in range(max_steps - 1):
             inference_timer.start()
@@ -286,8 +288,18 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 if update_values[agent] or done['__all__']:
                     # Only learn from timesteps where somethings happened
                     learn_timer.start()
-                    policy.step(agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent], agent_obs[agent],
-                                done[agent])
+                    call_step = True
+
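+                    # Down-sample uneventful transitions: when none of the flags checked by
+                    # FastTreeObs.full_action_required are set and the agent just moved forward,
+                    # only ~10% of these experiences are used for learning. Detected deadlocks
+                    # receive an extra -10 reward and are always used for learning.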
+                    if not (agent_obs[agent][7] == 1 or agent_obs[agent][8] == 1 or agent_obs[agent][4] == 1):
+                        if action_dict.get(agent) == RailEnvActions.MOVE_FORWARD:
+                            call_step = np.random.random() < 0.1
+                    if fast_tree_obs_check_agent_deadlock(agent_obs[agent]):
+                        all_rewards[agent] -= 10
+                        call_step = True
+                    if call_step:
+                        policy.step(agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
+                                    agent_obs[agent],
+                                    done[agent])
                     learn_timer.end()
 
                     agent_prev_obs[agent] = agent_obs[agent].copy()
@@ -474,8 +486,8 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=12500, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2, type=int)
+    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5400, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int)
     parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
                         type=int)
     parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
@@ -500,7 +512,7 @@ if __name__ == "__main__":
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
     parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
                         action='store_true')
-    parser.add_argument("--max_depth", help="max depth", default=2, type=int)
+    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
 
     training_params = parser.parse_args()
     env_params = [
diff --git a/reinforcement_learning/ppo/model.py b/reinforcement_learning/ppo/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..03d72c9ca4059d45a139305b69ee95f709977e07
--- /dev/null
+++ b/reinforcement_learning/ppo/model.py
@@ -0,0 +1,20 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class PolicyNetwork(nn.Module):
+    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128, hidsize3=32):
+        super().__init__()
+        self.fc1 = nn.Linear(state_size, hidsize1)
+        self.fc2 = nn.Linear(hidsize1, hidsize2)
+        # self.fc3 = nn.Linear(hidsize2, hidsize3)
+        self.output = nn.Linear(hidsize2, action_size)
+        self.softmax = nn.Softmax(dim=1)
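+        # Non-affine batch norm on the raw input acts as observation normalization.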
+        self.bn0 = nn.BatchNorm1d(state_size, affine=False)
+
+    def forward(self, inputs):
+        x = self.bn0(inputs.float())
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        # x = F.relu(self.fc3(x))
+        return self.softmax(self.output(x))
diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..e43cb3080bf5187e148f0b06e69749f46e840e9e
--- /dev/null
+++ b/reinforcement_learning/ppo/ppo_agent.py
@@ -0,0 +1,141 @@
+import os
+import random
+
+import numpy as np
+import torch
+from torch.distributions.categorical import Categorical
+
+from reinforcement_learning.policy import Policy
+from reinforcement_learning.ppo.model import PolicyNetwork
+from reinforcement_learning.ppo.replay_memory import Episode, ReplayBuffer
+
+BUFFER_SIZE = 128_000
+BATCH_SIZE = 8192
+GAMMA = 0.95
+LR = 0.5e-4
+CLIP_FACTOR = .005
+UPDATE_EVERY = 30
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+class PPOAgent(Policy):
+    def __init__(self, state_size, action_size, num_agents, env):
+        self.action_size = action_size
+        self.state_size = state_size
+        self.num_agents = num_agents
+        self.policy = PolicyNetwork(state_size, action_size).to(device)
+        self.old_policy = PolicyNetwork(state_size, action_size).to(device)
+        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=LR)
+        self.episodes = [Episode() for _ in range(num_agents)]
+        self.memory = ReplayBuffer(BUFFER_SIZE)
+        self.t_step = 0
+        self.loss = 0
+        self.env = env
+
+    def reset(self):
+        self.finished = [False] * len(self.episodes)
+        self.tot_reward = [0] * self.num_agents
+
+    # Decide on an action to take in the environment
+    def act(self, handle, state, eps=None):
+        # Sample from the stochastic policy; its softmax output already provides exploration,
+        # so the eps argument is accepted for interface compatibility but not used.
+        self.policy.eval()
+        with torch.no_grad():
+            output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
+            return Categorical(output).sample().item()
+
+    # Record the results of the agent's action and update the model
+    def step(self, handle, state, action, reward, next_state, done):
+        if not self.finished[handle]:
+            # Push experience into Episode memory
+            self.tot_reward[handle] += reward
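+            # Sparse reward: store 1 only on the step that completes this agent's episode, 0 otherwise.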
+            if done == 1:
+                reward = 1  # self.tot_reward[handle]
+            else:
+                reward = 0
+
+            self.episodes[handle].push(state, action, reward, next_state, done)
+
+            # When we finish the episode, discount rewards and push the experience into replay memory
+            if done:
+                self.episodes[handle].discount_rewards(GAMMA)
+                self.memory.push_episode(self.episodes[handle])
+                self.episodes[handle].reset()
+                self.finished[handle] = True
+
+        # Perform a gradient update every UPDATE_EVERY time steps
+        self.t_step = (self.t_step + 1) % UPDATE_EVERY
+        if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 4:
+            self._learn(*self.memory.sample(BATCH_SIZE, device))
+
+    def _clip_gradient(self, model, clip):
+        # Element-wise clamping of every gradient entry to [-clip, clip].
+        for p in model.parameters():
+            p.grad.data.clamp_(-clip, clip)
+
+    def _clip_gradient_by_norm(self, model, clip):
+        """Unused alternative: scales all gradients by a coefficient based on the total gradient norm."""
+        totalnorm = 0
+        for p in model.parameters():
+            if p.grad is not None:
+                modulenorm = p.grad.data.norm()
+                totalnorm += modulenorm ** 2
+        totalnorm = np.sqrt(totalnorm)
+        coeff = min(1, clip / (totalnorm + 1e-6))
+
+        for p in model.parameters():
+            if p.grad is not None:
+                p.grad.mul_(coeff)
+
+    def _learn(self, states, actions, rewards, next_state, done):
+        self.policy.train()
+
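+        # Probabilities that the current and the frozen "old" policy assign to the taken actions.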
+        responsible_outputs = torch.gather(self.policy(states), 1, actions)
+        old_responsible_outputs = torch.gather(self.old_policy(states), 1, actions).detach()
+
+        # rewards = rewards - rewards.mean()
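+        # PPO clipped surrogate objective, using the discounted episode returns in place of
+        # an advantage estimate.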
+        ratio = responsible_outputs / (old_responsible_outputs + 1e-5)
+        clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR)
+        loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean()
+        self.loss = loss
+
+        # Snapshot the current weights as the "old" policy for the next update,
+        # then take a gradient step on the clipped surrogate loss
+        self.old_policy.load_state_dict(self.policy.state_dict())
+        self.optimizer.zero_grad()
+        loss.backward()
+        # self._clip_gradient(self.policy, 1.0)
+        self.optimizer.step()
+
+    # Checkpointing methods
+    def save(self, filename):
+        # print("Saving model from checkpoint:", filename)
+        torch.save(self.policy.state_dict(), filename + ".policy")
+        torch.save(self.optimizer.state_dict(), filename + ".optimizer")
+
+    def load(self, filename):
+        print("load policy from file", filename)
+        if os.path.exists(filename + ".policy"):
+            print(' >> ', filename + ".policy")
+            try:
+                self.policy.load_state_dict(torch.load(filename + ".policy"))
+            except Exception:
+                print(" >> failed!")
+        if os.path.exists(filename + ".optimizer"):
+            print(' >> ', filename + ".optimizer")
+            try:
+                self.optimizer.load_state_dict(torch.load(filename + ".optimizer"))
+            except Exception:
+                print(" >> failed!")
diff --git a/reinforcement_learning/ppo/replay_memory.py b/reinforcement_learning/ppo/replay_memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..61a1b81bf3a338e031f3d0441058268206426ce7
--- /dev/null
+++ b/reinforcement_learning/ppo/replay_memory.py
@@ -0,0 +1,53 @@
+import torch
+import random
+import numpy as np
+from collections import namedtuple, deque
+from collections.abc import Iterable
+
+
+Transition = namedtuple("Experience", ("state", "action", "reward", "next_state", "done"))
+
+
+class Episode:
+    def __init__(self):
+        self.memory = []
+
+    def reset(self):
+        self.memory = []
+
+    def push(self, *args):
+        self.memory.append(tuple(args))
+
+    def discount_rewards(self, gamma):
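+        # Walk the episode backwards, replacing each stored reward with its discounted return.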
+        running_add = 0.
+        for i, (state, action, reward, *rest) in list(enumerate(self.memory))[::-1]:
+            running_add = running_add * gamma + reward
+            self.memory[i] = (state, action, running_add, *rest)
+
+
+class ReplayBuffer:
+    def __init__(self, buffer_size):
+        self.memory = deque(maxlen=buffer_size)
+
+    def push(self, state, action, reward, next_state, done):
+        self.memory.append(Transition(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done))
+
+    def push_episode(self, episode):
+        for step in episode.memory:
+            self.push(*step)
+
+    def sample(self, batch_size, device):
+        experiences = random.sample(self.memory, k=batch_size)
+
+        states      = torch.from_numpy(self.stack([e.state      for e in experiences])).float().to(device)
+        actions     = torch.from_numpy(self.stack([e.action     for e in experiences])).long().to(device)
+        rewards     = torch.from_numpy(self.stack([e.reward     for e in experiences])).float().to(device)
+        next_states = torch.from_numpy(self.stack([e.next_state for e in experiences])).float().to(device)
+        dones       = torch.from_numpy(self.stack([e.done       for e in experiences]).astype(np.uint8)).float().to(device)
+
+        return states, actions, rewards, next_states, dones
+
+    def stack(self, states):
+        sub_dims = states[0].shape[1:] if isinstance(states[0], Iterable) else [1]
+        return np.reshape(np.array(states), (len(states), *sub_dims))
+
+    def __len__(self):
+        return len(self.memory)
diff --git a/run.py b/run.py
index e9950fcf1335b52db2a3b7a15baa8d164f62c7be..ac6a3cb3589a93c43e44e68fe6faf9796d31235e 100644
--- a/run.py
+++ b/run.py
@@ -25,14 +25,14 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 VERBOSE = True
 
 # Checkpoint to use (remember to push it!)
-checkpoint = "./checkpoints/201103221432-3000.pth"
+checkpoint = "./checkpoints/201105173637-4700.pth" # 18.50097663335293 : Depth = 1
 
 # Use last action cache
 USE_ACTION_CACHE = True
 USE_DEAD_LOCK_AVOIDANCE_AGENT = False
 
 # Observation parameters (must match training parameters!)
-observation_tree_depth = 2
+observation_tree_depth = 1
 observation_radius = 10
 observation_max_path_depth = 30
 
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index b2a4bf72353fc87d8dcde2f430f1277d27c422bb..919be23b0eaafa1be76869a0c90ff18dd647e773 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -23,7 +23,7 @@ class FastTreeObs(ObservationBuilder):
 
     def __init__(self, max_depth):
         self.max_depth = max_depth
-        self.observation_dim = 30
+        self.observation_dim = 32
 
     def build_data(self):
         if self.env is not None:
@@ -287,7 +287,7 @@ class FastTreeObs(ObservationBuilder):
                     has_opp_agent, has_same_agent, has_switch, v = self._explore(handle, new_position, branch_direction)
                     visited.append(v)
 
-                    observation[10 + dir_loop] = 1
+                    observation[10 + dir_loop] = int(not np.isinf(new_cell_dist))
                     observation[14 + dir_loop] = has_opp_agent
                     observation[18 + dir_loop] = has_same_agent
                     observation[22 + dir_loop] = has_switch
@@ -301,12 +301,25 @@ class FastTreeObs(ObservationBuilder):
         observation[8] = int(agents_near_to_switch)
         observation[9] = int(agents_near_to_switch_all)
 
-        action = self.dead_lock_avoidance_agent.act([handle],0.0)
+        action = self.dead_lock_avoidance_agent.act([handle], 0.0)
         observation[26] = int(action == RailEnvActions.STOP_MOVING)
         observation[27] = int(action == RailEnvActions.MOVE_LEFT)
         observation[28] = int(action == RailEnvActions.MOVE_FORWARD)
         observation[29] = int(action == RailEnvActions.MOVE_RIGHT)
-
+        observation[30] = int(self.full_action_required(observation))
+        observation[31] = int(fast_tree_obs_check_agent_deadlock(observation))
         self.env.dev_obs_dict.update({handle: visited})
 
         return observation
+
+    def full_action_required(self, observation):
+        return observation[7] == 1 or observation[8] == 1 or observation[4] == 1
+
+
+def fast_tree_obs_check_agent_deadlock(observation):
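+    # Deadlock heuristic: the agent counts as stuck when the number of reachable branches
+    # (observation[10..13]) does not exceed the number of branches with an opposing agent
+    # (observation[14..17]).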
+    nbr_of_path = 0
+    nbr_of_blocked_path = 0
+    for dir_loop in range(4):
+        nbr_of_path += observation[10 + dir_loop]
+        nbr_of_blocked_path += int(observation[14 + dir_loop] > 0)
+    return nbr_of_path <= nbr_of_blocked_path