diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index c12750356efcccf7225bbe51e0c13fbfd44ac41b..1092d84e5e5fb2ef45424fb0ecd36823621d62a4 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -172,7 +172,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, get_action_size(), train_params)
-    if False:
+    if True:
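+        # Override: replace the DDDQN policy created above with a PPO agent.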
         policy = PPOAgent(state_size, get_action_size())
     if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py
index 16206e9f6330e8cfe1b96ea2b797e4ea4f1ecdaf..f80bea86ebf89cd62f8e5310ddbbab88dd15acb1 100644
--- a/reinforcement_learning/ppo_agent.py
+++ b/reinforcement_learning/ppo_agent.py
@@ -166,11 +166,6 @@ class PPOAgent(Policy):
             if self.use_replay_buffer:
                 self.memory.add(state_i, action_i, discounted_reward, state_next_i, done_i)
 
-        if self.use_replay_buffer:
-            if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
-                states, actions, rewards, next_states, dones, prob_actions = self.memory.sample()
-                return states, actions, rewards, next_states, dones, prob_actions
-
         # convert data to torch tensors
         states, actions, rewards, states_next, dones, prob_actions = \
             torch.tensor(state_list, dtype=torch.float).to(self.device), \
@@ -195,7 +190,14 @@ class PPOAgent(Policy):
                 states, actions, rewards, states_next, dones, probs_action = \
                     self._convert_transitions_to_torch_tensors(agent_episode_history)
                 # Optimize policy for K epochs:
-                for _ in range(int(self.K_epoch)):
+                for k_loop in range(int(self.K_epoch)):
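+                    # From the second epoch onward, draw a fresh mini-batch from the
+                    # replay buffer (when enabled and sufficiently filled) instead of
+                    # reusing the episode transitions for every optimization epoch.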
+                    if self.use_replay_buffer and k_loop > 0:
+                        if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
+                            states, actions, rewards, states_next, dones, probs_action = self.memory.sample()
+
                     # Evaluating actions (actor) and values (critic)
                     logprobs, state_values, dist_entropy = self.actor_critic_model.evaluate(states, actions)