diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 23cd25ad6b5830a73c07f01b0b3ebb5d93651e66..71e3efde65f90f61ed99fd138689982cfdd48e71 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -207,8 +207,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, action_size, train_params)
-    if False:
-        policy = PPOAgent(state_size, action_size, n_agents)
+    if True:
+        policy = PPOAgent(state_size, action_size)
     # Load existing policy
-    if train_params.load_policy is not "":
+    if train_params.load_policy != "":
         policy.load(train_params.load_policy)
@@ -256,8 +256,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
         # Reset environment
         reset_timer.start()
-        number_of_agents = 2 # int(min(n_agents, 1 + np.floor(episode_idx / 200)))
-        train_env_params.n_agents = episode_idx % number_of_agents + 1
+        number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 200)))
+        train_env_params.n_agents = 1  # episode_idx % number_of_agents + 1
 
         train_env = create_rail_env(train_env_params, tree_observation)
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
@@ -397,7 +397,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             env_renderer.close_window()
 
         # Print logs
-        if episode_idx % checkpoint_interval == 0:
+        if episode_idx % checkpoint_interval == 0 and episode_idx > 0:
             policy.save('./checkpoints/' + training_id + '-' + str(episode_idx) + '.pth')
 
             if save_replay_buffer:
@@ -548,7 +548,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=2000, type=int)
+    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=1000, type=int)
     parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=0,
                         type=int)
     parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
@@ -581,7 +581,7 @@ if __name__ == "__main__":
     env_params = [
         {
             # Test_0
-            "n_agents": 5,
+            "n_agents": 1,
             "x_dim": 25,
             "y_dim": 25,
             "n_cities": 2,
@@ -592,6 +592,17 @@ if __name__ == "__main__":
         },
         {
             # Test_1
+            "n_agents": 5,
+            "x_dim": 25,
+            "y_dim": 25,
+            "n_cities": 2,
+            "max_rails_between_cities": 2,
+            "max_rails_in_city": 3,
+            "malfunction_rate": 1 / 50,
+            "seed": 0
+        },
+        {
+            # Test_2
             "n_agents": 10,
             "x_dim": 30,
             "y_dim": 30,
@@ -602,7 +613,7 @@ if __name__ == "__main__":
             "seed": 0
         },
         {
-            # Test_2
+            # Test_3
             "n_agents": 20,
             "x_dim": 35,
             "y_dim": 35,
diff --git a/reinforcement_learning/multi_policy.py b/reinforcement_learning/multi_policy.py
index 7b15e3735e19035362952d4be00b7046876c114d..ddd45b5870e417aac176afd350561b921b283d83 100644
--- a/reinforcement_learning/multi_policy.py
+++ b/reinforcement_learning/multi_policy.py
@@ -12,7 +12,7 @@ class MultiPolicy(Policy):
         self.memory = []
         self.loss = 0
         self.deadlock_avoidance_policy = DeadLockAvoidanceAgent(env, action_size, False)
-        self.ppo_policy = PPOAgent(state_size + action_size, action_size, n_agents, env)
+        self.ppo_policy = PPOAgent(state_size + action_size, action_size)
 
     def load(self, filename):
         self.ppo_policy.load(filename)
diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py
index 6449c1298469da9c8d4ffd35ee6d1baaefefeb2b..21ce12d3653e7042e6e7f7585ff1cfaa0d2da104 100644
--- a/reinforcement_learning/ppo/ppo_agent.py
+++ b/reinforcement_learning/ppo/ppo_agent.py
@@ -21,19 +21,27 @@ print("device:", device)
 
 
 class DataBuffers:
-    def __init__(self, n_agents):
-        self.memory = [[]] * n_agents
+    def __init__(self):
+        self.reset()
 
     def __len__(self):
         """Return the current size of internal memory."""
         return len(self.memory)
 
+    def reset(self):
+        self.memory = {}
+
+    def get_transitions(self, handle):
+        return self.memory.get(handle, [])
+
+    def push_transition(self, handle, transition):
+        self.memory.setdefault(handle, []).append(transition)
+
 
 class PPOAgent(Policy):
-    def __init__(self, state_size, action_size, n_agents):
+    def __init__(self, state_size, action_size):
         super(PPOAgent, self).__init__()
-        self.n_agents = n_agents
-        self.memory = DataBuffers(n_agents)
+        self.memory = DataBuffers()
         self.loss = 0
         self.fc1 = nn.Linear(state_size, 256)
         self.fc_pi = nn.Linear(256, action_size)
@@ -60,75 +68,82 @@ class PPOAgent(Policy):
         a = m.sample().item()
         return a
 
-    # Record the results of the agent's action and update the model
     def step(self, handle, state, action, reward, next_state, done):
+        # Record the results of the agent's action as transition
         prob = self.pi(torch.from_numpy(state).float())
-        self.memory.memory[handle].append(((state, action, reward, next_state, prob[action].item(), done)))
+        transition = (state, action, reward, next_state, prob[action].item(), done)
+        self.memory.push_transition(handle, transition)
 
-    def make_batch(self, data_array):
-        s_lst, a_lst, r_lst, s_next_lst, prob_a_lst, done_lst = [], [], [], [], [], []
+    def _convert_transitions_to_torch(self, transitions_array):
+        state_list, action_list, reward_list, state_next_list, prob_a_list, done_list = [], [], [], [], [], []
         total_reward = 0
-        for transition in data_array:
-            s, a, r, s_next, prob_a, done = transition
-
-            s_lst.append(s)
-            a_lst.append([a])
-            if True:
-                total_reward += r
-                if done:
-                    r_lst.append([1])
-                else:
-                    r_lst.append([0])
+        for transition in transitions_array:
+            state_i, action_i, reward_i, state_next_i, prob_action_i, done_i = transition
+
+            state_list.append(state_i)
+            action_list.append([action_i])
+            total_reward += reward_i
+            if done_i:
+                reward_list.append([max(1.0, total_reward)])
+                done_list.append([1])
             else:
-                r_lst.append([r])
-            s_next_lst.append(s_next)
-            prob_a_lst.append([prob_a])
-            done_mask = 1 - int(done)
-            done_lst.append([done_mask])
+                reward_list.append([0])
+                done_list.append([0])
+            state_next_list.append(state_next_i)
+            prob_a_list.append([prob_action_i])
 
-        total_reward = max(1.0, total_reward)
-        for i in range(len(r_lst)):
-            r_lst[i][0] = r_lst[i][0] * total_reward
+        state, action, reward, s_next, done, prob_action = torch.tensor(state_list, dtype=torch.float), \
+                                                           torch.tensor(action_list), \
+                                                           torch.tensor(reward_list), \
+                                                           torch.tensor(state_next_list, dtype=torch.float), \
+                                                           torch.tensor(done_list, dtype=torch.float), \
+                                                           torch.tensor(prob_a_list)
 
-        s, a, r, s_next, done_mask, prob_a = torch.tensor(s_lst, dtype=torch.float), \
-                                             torch.tensor(a_lst), \
-                                             torch.tensor(r_lst), \
-                                             torch.tensor(s_next_lst, dtype=torch.float), \
-                                             torch.tensor(done_lst, dtype=torch.float), \
-                                             torch.tensor(prob_a_lst)
-
-        return s, a, r, s_next, done_mask, prob_a
+        return state, action, reward, s_next, done, prob_action
 
     def train_net(self):
-        for handle in range(self.n_agents):
+        for handle in list(self.memory.memory.keys()):
-            if len(self.memory.memory[handle]) > 0:
-                s, a, r, s_next, done_mask, prob_a = self.make_batch(self.memory.memory[handle])
-                for i in range(K_EPOCH):
-                    td_target = r + GAMMA * self.v(s_next) * done_mask
-                    delta = td_target - self.v(s)
-                    delta = delta.detach().numpy()
-
-                    advantage_lst = []
-                    advantage = 0.0
-                    for delta_t in delta[::-1]:
-                        advantage = GAMMA * LMBDA * advantage + delta_t[0]
-                        advantage_lst.append([advantage])
-                    advantage_lst.reverse()
-                    advantage = torch.tensor(advantage_lst, dtype=torch.float)
-
-                    pi = self.pi(s, softmax_dim=1)
-                    pi_a = pi.gather(1, a)
-                    ratio = torch.exp(torch.log(pi_a) - torch.log(prob_a))  # a/b == exp(log(a)-log(b))
-
-                    surr1 = ratio * advantage
-                    surr2 = torch.clamp(ratio, 1 - EPS_CLIP, 1 + EPS_CLIP) * advantage
-                    loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self.v(s), td_target.detach())
+            agent_episode_history = self.memory.get_transitions(handle)
+            if len(agent_episode_history) > 0:
+                # convert the replay buffer to torch tensors (arrays)
+                state, action, reward, state_next, done, prob_action = \
+                    self._convert_transitions_to_torch(agent_episode_history)
 
+                # run K_EPOCH optimisation steps
+                for i in range(K_EPOCH):
+                    # temporal difference function / and prepare advantage function data
+                    estimated_target_value = reward + GAMMA * self.v(state_next) * (1.0 - done)
+                    difference_to_expected_value_deltas = estimated_target_value - self.v(state)
+                    difference_to_expected_value_deltas = difference_to_expected_value_deltas.detach().numpy()
+
+                    # build advantage function and convert it to torch tensor (array)
+                    advantage_list = []
+                    advantage_value = 0.0
+                    for difference_to_expected_value_t in difference_to_expected_value_deltas[::-1]:
+                        advantage_value = LMBDA * advantage_value + difference_to_expected_value_t[0]
+                        advantage_list.append([advantage_value])
+                    advantage_list.reverse()
+                    advantage = torch.tensor(advantage_list, dtype=torch.float)
+
+                    pi_action = self.pi(state, softmax_dim=1).gather(1, action)
+                    ratio = torch.exp(torch.log(pi_action) - torch.log(prob_action))  # a/b == exp(log(a)-log(b))
+                    # Normal Policy Gradient objective
+                    surrogate_objective = ratio * advantage
+                    # clipped version of Normal Policy Gradient objective
+                    clipped_surrogate_objective = torch.clamp(ratio, 1 - EPS_CLIP, 1 + EPS_CLIP) * advantage
+                    # value function loss
+                    value_loss = F.mse_loss(self.v(state), estimated_target_value.detach())
+                    # loss
+                    loss = -torch.min(surrogate_objective, clipped_surrogate_objective) + value_loss
+
+                    # update policy and actor networks
                     self.optimizer.zero_grad()
                     loss.mean().backward()
                     self.optimizer.step()
+
+                    # store current loss to the agent
                     self.loss = loss.mean().detach().numpy()
-        self.memory = DataBuffers(self.n_agents)
+        self.memory.reset()
 
     def end_episode(self, train):
         if train:
@@ -143,10 +158,10 @@ class PPOAgent(Policy):
         torch.save(self.optimizer.state_dict(), filename + ".optimizer")
 
     def _load(self, obj, filename):
-        if os.path.exists(filename + ".policy"):
-            print(' >> ', filename + ".policy")
+        if os.path.exists(filename):
+            print(' >> ', filename)
             try:
-                obj.load_state_dict(torch.load(filename + ".policy", map_location=device))
+                obj.load_state_dict(torch.load(filename, map_location=device))
             except:
                 print(" >> failed!")
         return obj
@@ -159,7 +174,7 @@ class PPOAgent(Policy):
         self.optimizer = self._load(self.optimizer, filename + ".optimizer")
 
     def clone(self):
-        policy = PPOAgent(self.state_size, self.action_size, self.num_agents)
+        policy = PPOAgent(self.state_size, self.action_size)
         policy.fc1 = copy.deepcopy(self.fc1)
         policy.fc_pi = copy.deepcopy(self.fc_pi)
         policy.fc_v = copy.deepcopy(self.fc_v)
diff --git a/run.py b/run.py
index d8d298cd355f19c4a4b278a0412a2c312ba65236..e31971e0305e1a3fd2e5b7a209809f47b4c93ef7 100644
--- a/run.py
+++ b/run.py
@@ -105,7 +105,7 @@ action_size = 5
 if not USE_PPO_AGENT:
     policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
 else:
-    policy = PPOAgent(state_size, action_size, 10)
+    policy = PPOAgent(state_size, action_size)
 policy.load(checkpoint)
 
 #####################################################################