diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 134a2c2e1ddb6c45d4ab806a3906b64c1298b529..f7cba9b6f9c94190e556f998b05c50fc5fa3c79b 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -16,8 +16,10 @@ from reinforcement_learning.policy import Policy
 class DDDQNPolicy(Policy):
     """Dueling Double DQN policy"""
 
-    def __init__(self, state_size, action_size, parameters, evaluation_mode=False):
-        self.parameters = parameters
+    def __init__(self, state_size, action_size, in_parameters, evaluation_mode=False):
+        super(DDDQNPolicy, self).__init__()
+
+        self.ddqn_parameters = in_parameters
         self.evaluation_mode = evaluation_mode
 
         self.state_size = state_size
@@ -26,17 +28,17 @@ class DDDQNPolicy(Policy):
         self.hidsize = 128
 
         if not evaluation_mode:
-            self.hidsize = parameters.hidden_size
-            self.buffer_size = parameters.buffer_size
-            self.batch_size = parameters.batch_size
-            self.update_every = parameters.update_every
-            self.learning_rate = parameters.learning_rate
-            self.tau = parameters.tau
-            self.gamma = parameters.gamma
-            self.buffer_min_size = parameters.buffer_min_size
+            self.hidsize = self.ddqn_parameters.hidden_size
+            self.buffer_size = self.ddqn_parameters.buffer_size
+            self.batch_size = self.ddqn_parameters.batch_size
+            self.update_every = self.ddqn_parameters.update_every
+            self.learning_rate = self.ddqn_parameters.learning_rate
+            self.tau = self.ddqn_parameters.tau
+            self.gamma = self.ddqn_parameters.gamma
+            self.buffer_min_size = self.ddqn_parameters.buffer_min_size
 
             # Device
-        if parameters.use_gpu and torch.cuda.is_available():
+        if self.ddqn_parameters.use_gpu and torch.cuda.is_available():
             self.device = torch.device("cuda:0")
             # print("🐇 Using GPU")
         else:
@@ -153,7 +155,7 @@ class DDDQNPolicy(Policy):
         self._learn()
 
     def clone(self):
-        me = DDDQNPolicy(self.state_size, self.action_size, self.parameters, evaluation_mode=True)
-        me.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+        me = DDDQNPolicy(self.state_size, self.action_size, self.ddqn_parameters, evaluation_mode=True)
+        me.qnetwork_local = copy.deepcopy(self.qnetwork_local)
         me.qnetwork_target = copy.deepcopy(self.qnetwork_target)
         return me
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 6944eab883acd92f87e46be40d93d6ad94cfa400..23cd25ad6b5830a73c07f01b0b3ebb5d93651e66 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -22,7 +22,6 @@ from torch.utils.tensorboard import SummaryWriter
 
 from reinforcement_learning.dddqn_policy import DDDQNPolicy
 from reinforcement_learning.ppo.ppo_agent import PPOAgent
-from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 
 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
@@ -208,7 +207,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, action_size, train_params)
-    if True:
+    if False:
         policy = PPOAgent(state_size, action_size, n_agents)
     # Load existing policy
     if train_params.load_policy is not "":
@@ -257,7 +256,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
         # Reset environment
         reset_timer.start()
-        number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 200)))
+        number_of_agents = 2 # int(min(n_agents, 1 + np.floor(episode_idx / 200)))
         train_env_params.n_agents = episode_idx % number_of_agents + 1
 
         train_env = create_rail_env(train_env_params, tree_observation)
@@ -289,9 +288,10 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         max_steps = train_env._max_episode_steps
 
         # Run episode
+        policy.start_episode(train=True)
         for step in range(max_steps - 1):
             inference_timer.start()
-            policy.start_step()
+            policy.start_step(train=True)
             for agent_handle in train_env.get_agent_handles():
                 agent = train_env.agents[agent_handle]
                 if info['action_required'][agent_handle]:
@@ -306,7 +306,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                     update_values[agent_handle] = False
                     action = 0
                 action_dict.update({agent_handle: action})
-            policy.end_step()
+            policy.end_step(train=True)
             inference_timer.end()
 
             # Environment step
@@ -378,6 +378,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             if done['__all__']:
                 break
 
+        policy.end_episode(train=True)
         # Epsilon decay
         eps_start = max(eps_end, eps_decay * eps_start)
 
@@ -507,8 +508,9 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
         obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
         final_step = 0
 
+        policy.start_episode(train=False)
         for step in range(max_steps - 1):
-            policy.start_step()
+            policy.start_step(train=False)
             for agent in env.get_agent_handles():
                 if tree_observation.check_is_observation_valid(agent_obs[agent]):
                     agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], tree_depth=tree_depth,
@@ -519,7 +521,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
                     if tree_observation.check_is_observation_valid(agent_obs[agent]):
                         action = policy.act(agent_obs[agent], eps=0.0)
                 action_dict.update({agent: action})
-            policy.end_step()
+            policy.end_step(train=False)
             obs, all_rewards, done, info = env.step(action_dict)
 
             for agent in env.get_agent_handles():
@@ -529,7 +531,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 
             if done['__all__']:
                 break
-
+        policy.end_episode(train=False)
         normalized_score = score / (max_steps * env.get_num_agents())
         scores.append(normalized_score)
 
@@ -546,7 +548,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=10000, type=int)
+    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=2000, type=int)
     parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=0,
                         type=int)
     parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
diff --git a/reinforcement_learning/multi_policy.py b/reinforcement_learning/multi_policy.py
index 52183c7fa80496857e372be0e44e8855815c321f..7b15e3735e19035362952d4be00b7046876c114d 100644
--- a/reinforcement_learning/multi_policy.py
+++ b/reinforcement_learning/multi_policy.py
@@ -1,5 +1,4 @@
 import numpy as np
-from flatland.envs.rail_env import RailEnvActions
 
 from reinforcement_learning.policy import Policy
 from reinforcement_learning.ppo.ppo_agent import PPOAgent
@@ -54,10 +53,10 @@ class MultiPolicy(Policy):
         self.ppo_policy.test()
         self.deadlock_avoidance_policy.test()
 
-    def start_step(self):
-        self.deadlock_avoidance_policy.start_step()
-        self.ppo_policy.start_step()
+    def start_step(self, train):
+        self.deadlock_avoidance_policy.start_step(train)
+        self.ppo_policy.start_step(train)
 
-    def end_step(self):
-        self.deadlock_avoidance_policy.end_step()
-        self.ppo_policy.end_step()
+    def end_step(self, train):
+        self.deadlock_avoidance_policy.end_step(train)
+        self.ppo_policy.end_step(train)
diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
index c7300de8dc843d8f08b86fcb43f1bc2993a7121b..33726d63e2c881c0a0db75bf1cd8f279b38fc1e7 100644
--- a/reinforcement_learning/policy.py
+++ b/reinforcement_learning/policy.py
@@ -1,4 +1,9 @@
-class Policy:
+import torch.nn as nn
+
+class Policy(nn.Module):
+    def __init__(self):
+        super(Policy, self).__init__()
+
     def step(self, handle, state, action, reward, next_state, done):
         raise NotImplementedError
 
@@ -11,10 +16,17 @@
     def load(self, filename):
         raise NotImplementedError
 
-    def start_step(self):
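+    # Lifecycle hooks called around each env step/episode; 'train' tells the policy whether it is collecting experience for learning.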
+    def start_step(self, train):
+        pass
+
+    def end_step(self, train):
+        pass
+
+    def start_episode(self, train):
         pass
 
-    def end_step(self):
+    def end_episode(self, train):
         pass
 
     def load_replay_buffer(self, filename):
diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py
index 09bafd79558e19f0266adba4742f42f6f3e373b1..6449c1298469da9c8d4ffd35ee6d1baaefefeb2b 100644
--- a/reinforcement_learning/ppo/ppo_agent.py
+++ b/reinforcement_learning/ppo/ppo_agent.py
@@ -1,139 +1,184 @@
 import copy
 import os
 
-import numpy as np
 import torch
-from torch.distributions.categorical import Categorical
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.distributions import Categorical
 
 from reinforcement_learning.policy import Policy
-from reinforcement_learning.ppo.model import PolicyNetwork
-from reinforcement_learning.ppo.replay_memory import Episode, ReplayBuffer
 
-BUFFER_SIZE = 128_000
-BATCH_SIZE = 8192
-GAMMA = 0.95
-LR = 0.5e-4
-CLIP_FACTOR = .005
-UPDATE_EVERY = 30
+# PPO hyperparameters
+LEARNING_RATE = 0.00001
+GAMMA = 0.98
+LMBDA = 0.95
+EPS_CLIP = 0.1
+K_EPOCH = 3
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 print("device:", device)
 
 
+class DataBuffers:
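+    """Per-agent buffers: one list of collected transitions for each agent handle."""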
+    def __init__(self, n_agents):
+        self.memory = [[] for _ in range(n_agents)]
+
+    def __len__(self):
+        """Return the current size of internal memory."""
+        return len(self.memory)
+
+
 class PPOAgent(Policy):
-    def __init__(self, state_size, action_size, num_agents):
-        self.action_size = action_size
-        self.state_size = state_size
-        self.num_agents = num_agents
-        self.policy = PolicyNetwork(state_size, action_size).to(device)
-        self.old_policy = PolicyNetwork(state_size, action_size).to(device)
-        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=LR)
-        self.episodes = [Episode() for _ in range(num_agents)]
-        self.memory = ReplayBuffer(BUFFER_SIZE)
-        self.t_step = 0
+    def __init__(self, state_size, action_size, n_agents):
+        super(PPOAgent, self).__init__()
+        self.state_size = state_size
+        self.action_size = action_size
+        self.n_agents = n_agents
+        self.memory = DataBuffers(n_agents)
         self.loss = 0
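+        # Actor-critic network: shared feature layer with separate policy and value heads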
+        self.fc1 = nn.Linear(state_size, 256)
+        self.fc_pi = nn.Linear(256, action_size)
+        self.fc_v = nn.Linear(256, 1)
+        self.optimizer = optim.Adam(self.parameters(), lr=LEARNING_RATE)
 
     def reset(self):
-        self.finished = [False] * len(self.episodes)
-        self.tot_reward = [0] * self.num_agents
+        pass
 
-    # Decide on an action to take in the environment
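+    # Actor head: maps a state to a probability distribution over actions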
+    def pi(self, x, softmax_dim=0):
+        x = torch.tanh(self.fc1(x))
+        x = self.fc_pi(x)
+        prob = F.softmax(x, dim=softmax_dim)
+        return prob
+
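+    # Critic head: maps a state to a scalar value estimate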
+    def v(self, x):
+        x = torch.tanh(self.fc1(x))
+        v = self.fc_v(x)
+        return v
 
     def act(self, state, eps=None):
-        self.policy.eval()
-        with torch.no_grad():
-            output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
-            ret = Categorical(output).sample().item()
-            return ret
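+        # Sample an action from the current policy distribution ('eps' is unused, kept for interface compatibility)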
+        prob = self.pi(torch.from_numpy(state).float())
+        m = Categorical(prob)
+        a = m.sample().item()
+        return a
 
     # Record the results of the agent's action and update the model
     def step(self, handle, state, action, reward, next_state, done):
-        if not self.finished[handle]:
-            # Push experience into Episode memory
-            self.tot_reward[handle] += reward
-            if done == 1:
-                reward = 1  # self.tot_reward[handle]
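+        # Store the transition together with pi(a|s) of the acting policy, needed later for the PPO importance ratio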
+        prob = self.pi(torch.from_numpy(state).float())
+        self.memory.memory[handle].append(((state, action, reward, next_state, prob[action].item(), done)))
+
+    def make_batch(self, data_array):
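+        # Convert one agent's collected transitions into training tensors; rewards are replaced by a sparse signal (1 at the done step, 0 otherwise)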
+        s_lst, a_lst, r_lst, s_next_lst, prob_a_lst, done_lst = [], [], [], [], [], []
+        total_reward = 0
+        for transition in data_array:
+            s, a, r, s_next, prob_a, done = transition
+
+            s_lst.append(s)
+            a_lst.append([a])
+            if True:
+                total_reward += r
+                if done:
+                    r_lst.append([1])
+                else:
+                    r_lst.append([0])
             else:
-                reward = 0
-
-            self.episodes[handle].push(state, action, reward, next_state, done)
-
-            # When we finish the episode, discount rewards and push the experience into replay memory
-            if done:
-                self.episodes[handle].discount_rewards(GAMMA)
-                self.memory.push_episode(self.episodes[handle])
-                self.episodes[handle].reset()
-                self.finished[handle] = True
-
-        # Perform a gradient update every UPDATE_EVERY time steps
-        self.t_step = (self.t_step + 1) % UPDATE_EVERY
-        if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 4:
-            self._learn(*self.memory.sample(BATCH_SIZE, device))
-
-    def _clip_gradient(self, model, clip):
-
-        for p in model.parameters():
-            p.grad.data.clamp_(-clip, clip)
-        return
-
-        """Computes a gradient clipping coefficient based on gradient norm."""
-        totalnorm = 0
-        for p in model.parameters():
-            if p.grad is not None:
-                modulenorm = p.grad.data.norm()
-                totalnorm += modulenorm ** 2
-        totalnorm = np.sqrt(totalnorm)
-        coeff = min(1, clip / (totalnorm + 1e-6))
-
-        for p in model.parameters():
-            if p.grad is not None:
-                p.grad.mul_(coeff)
-
-    def _learn(self, states, actions, rewards, next_state, done):
-        self.policy.train()
-
-        responsible_outputs = torch.gather(self.policy(states), 1, actions)
-        old_responsible_outputs = torch.gather(self.old_policy(states), 1, actions).detach()
-
-        # rewards = rewards - rewards.mean()
-        ratio = responsible_outputs / (old_responsible_outputs + 1e-5)
-        clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR)
-        loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean()
-        self.loss = loss
-
-        # Compute loss and perform a gradient step
-        self.old_policy.load_state_dict(self.policy.state_dict())
-        self.optimizer.zero_grad()
-        loss.backward()
-        # self._clip_gradient(self.policy, 1.0)
-        self.optimizer.step()
+                r_lst.append([r])
+            s_next_lst.append(s_next)
+            prob_a_lst.append([prob_a])
+            done_mask = 1 - int(done)
+            done_lst.append([done_mask])
+
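+        # Scale the sparse rewards by the episode's accumulated reward, clipped below at 1.0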
+        total_reward = max(1.0, total_reward)
+        for i in range(len(r_lst)):
+            r_lst[i][0] = r_lst[i][0] * total_reward
+
+        s, a, r, s_next, done_mask, prob_a = torch.tensor(s_lst, dtype=torch.float), \
+                                             torch.tensor(a_lst), \
+                                             torch.tensor(r_lst), \
+                                             torch.tensor(s_next_lst, dtype=torch.float), \
+                                             torch.tensor(done_lst, dtype=torch.float), \
+                                             torch.tensor(prob_a_lst)
+
+        return s, a, r, s_next, done_mask, prob_a
+
+    def train_net(self):
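+        # Run K_EPOCH PPO update epochs per agent on the transitions collected since the last training call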
+        for handle in range(self.n_agents):
+            if len(self.memory.memory[handle]) > 0:
+                s, a, r, s_next, done_mask, prob_a = self.make_batch(self.memory.memory[handle])
+                for i in range(K_EPOCH):
+                    td_target = r + GAMMA * self.v(s_next) * done_mask
+                    delta = td_target - self.v(s)
+                    delta = delta.detach().numpy()
+
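+                    # Generalized Advantage Estimation (GAE): accumulate discounted TD errors backwards in time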
+                    advantage_lst = []
+                    advantage = 0.0
+                    for delta_t in delta[::-1]:
+                        advantage = GAMMA * LMBDA * advantage + delta_t[0]
+                        advantage_lst.append([advantage])
+                    advantage_lst.reverse()
+                    advantage = torch.tensor(advantage_lst, dtype=torch.float)
+
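+                    # Clipped PPO surrogate loss plus a smooth L1 value-function loss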
+                    pi = self.pi(s, softmax_dim=1)
+                    pi_a = pi.gather(1, a)
+                    ratio = torch.exp(torch.log(pi_a) - torch.log(prob_a))  # a/b == exp(log(a)-log(b))
+
+                    surr1 = ratio * advantage
+                    surr2 = torch.clamp(ratio, 1 - EPS_CLIP, 1 + EPS_CLIP) * advantage
+                    loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self.v(s), td_target.detach())
+
+                    self.optimizer.zero_grad()
+                    loss.mean().backward()
+                    self.optimizer.step()
+                    self.loss = loss.mean().detach().numpy()
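+        # Discard the processed transitions: PPO is on-policy, so stale data must not be reused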
+        self.memory = DataBuffers(self.n_agents)
+
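+    # Called by the training/evaluation loops after each episode; runs a PPO update only when training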
+    def end_episode(self, train):
+        if train:
+            self.train_net()
 
     # Checkpointing methods
     def save(self, filename):
         # print("Saving model from checkpoint:", filename)
-        torch.save(self.policy.state_dict(), filename + ".policy")
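+        # Each layer and the optimizer are written to their own checkpoint file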
+        torch.save(self.fc1.state_dict(), filename + ".fc1")
+        torch.save(self.fc_pi.state_dict(), filename + ".fc_pi")
+        torch.save(self.fc_v.state_dict(), filename + ".fc_v")
         torch.save(self.optimizer.state_dict(), filename + ".optimizer")
 
-    def load(self, filename):
-        print("load policy from file", filename)
+    def _load(self, obj, filename):
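+        # Load a saved state dict into the given module/optimizer if the checkpoint file exists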
-        if os.path.exists(filename + ".policy"):
-            print(' >> ', filename + ".policy")
+        if os.path.exists(filename):
+            print(' >> ', filename)
             try:
-                self.policy.load_state_dict(torch.load(filename + ".policy", map_location=device))
+                obj.load_state_dict(torch.load(filename, map_location=device))
             except:
                 print(" >> failed!")
-                pass
-        if os.path.exists(filename + ".optimizer"):
-            print(' >> ', filename + ".optimizer")
-            try:
-                self.optimizer.load_state_dict(torch.load(filename + ".optimizer", map_location=device))
-            except:
-                print(" >> failed!")
-                pass
+        return obj
+
+    def load(self, filename):
+        print("load policy from file", filename)
+        self.fc1 = self._load(self.fc1, filename + ".fc1")
+        self.fc_pi = self._load(self.fc_pi, filename + ".fc_pi")
+        self.fc_v = self._load(self.fc_v, filename + ".fc_v")
+        self.optimizer = self._load(self.optimizer, filename + ".optimizer")
 
     def clone(self):
-        policy = PPOAgent(self.state_size, self.action_size, self.num_agents)
-        policy.policy = copy.deepcopy(self.policy)
-        policy.old_policy = copy.deepcopy(self.old_policy)
+        policy = PPOAgent(self.state_size, self.action_size, self.n_agents)
+        policy.fc1 = copy.deepcopy(self.fc1)
+        policy.fc_pi = copy.deepcopy(self.fc_pi)
+        policy.fc_v = copy.deepcopy(self.fc_v)
         policy.optimizer = copy.deepcopy(self.optimizer)
-        return self
+        return policy
diff --git a/reinforcement_learning/single_agent_training.py b/reinforcement_learning/single_agent_training.py
index bfcc88656c8b37a8c09e72b51701d0750cf7f238..a5ee6c5132652757cdd8fd8dab6992e08fdfd14b 100644
--- a/reinforcement_learning/single_agent_training.py
+++ b/reinforcement_learning/single_agent_training.py
@@ -103,9 +103,9 @@ def train_agent(n_episodes):
         'buffer_size': int(1e5),
         'batch_size': 32,
         'update_every': 8,
-        'learning_rate': 0.5e-4,
+        'learning_rate': 0.5e-4,
         'tau': 1e-3,
-        'gamma': 0.99,
+        'gamma': 0.99,
         'buffer_min_size': 0,
         'hidden_size': 256,
         'use_gpu': False
diff --git a/run.py b/run.py
index a5e91a35c6b27af1e647e0d0efd812eaa40bcfbe..d8d298cd355f19c4a4b278a0412a2c312ba65236 100644
--- a/run.py
+++ b/run.py
@@ -163,6 +163,7 @@ while True:
     if USE_DEAD_LOCK_AVOIDANCE_AGENT:
         policy = DeadLockAvoidanceAgent(local_env, action_size)
 
+    policy.start_episode(train=False)
     while True:
         try:
             #####################################################################
@@ -175,7 +176,7 @@ while True:
             if not check_if_all_blocked(env=local_env):
                 time_start = time.time()
                 action_dict = {}
-                policy.start_step()
+                policy.start_step(train=False)
                 if USE_DEAD_LOCK_AVOIDANCE_AGENT:
                     observation = np.zeros((local_env.get_num_agents(), 2))
                 for agent_handle in range(nb_agents):
@@ -203,7 +204,7 @@ while True:
                         agent_last_obs[agent_handle] = observation[agent_handle]
                         agent_last_action[agent_handle] = action
 
-                policy.end_step()
+                policy.end_step(train=False)
                 agent_time = time.time() - time_start
                 time_taken_by_controller.append(agent_time)
 
@@ -254,6 +255,8 @@ while True:
             print("Timeout! Will skip this episode and go to the next.", err)
             break
 
+    policy.end_episode(train=False)
+
     np_time_taken_by_controller = np.array(time_taken_by_controller)
     np_time_taken_per_step = np.array(time_taken_per_step)
     print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(),
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index 4a371350333bbe6e8295331ada53e8f8ada83b3b..07840db7b28505ea228db35e4e10f961c4015313 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -110,12 +110,12 @@ class DeadLockAvoidanceAgent(Policy):
                         else:
                             self.switches[pos].append(dir)
 
-    def start_step(self):
+    def start_step(self, train):
         self.build_agent_position_map()
         self.shortest_distance_mapper()
         self.extract_agent_can_move()
 
-    def end_step(self):
+    def end_step(self, train):
         pass
 
     def get_actions(self):
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index 3238ee54bdf64986ea77c8a766e958d86a8c34eb..b104916caeb02b95334c84b1477018bc0c5908b4 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -76,7 +76,7 @@ class FastTreeObs(ObservationBuilder):
                                     self.switches_neighbours[pos].append(dir)
 
     def find_all_cell_where_agent_can_choose(self):
-        # prepare the data - collect all cells where the agent can choose more than FORWARD/STOP.
+        # precompute the switch data - collect all cells where the agent can choose more than FORWARD/STOP.
         self.find_all_switches()
         self.find_all_switch_neighbours()
 
@@ -243,9 +243,9 @@ class FastTreeObs(ObservationBuilder):
         return has_opp_agent, has_same_agent, has_target, visited, min_dist
 
     def get_many(self, handles: Optional[List[int]] = None):
-        self.dead_lock_avoidance_agent.start_step()
+        self.dead_lock_avoidance_agent.start_step(train=False)
         observations = super().get_many(handles)
-        self.dead_lock_avoidance_agent.end_step()
+        self.dead_lock_avoidance_agent.end_step(train=False)
         return observations
 
     def get(self, handle):