From 52a015e1a8e85795926f54ee034a6318c6deb581 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Wed, 16 Dec 2020 20:13:41 +0100
Subject: [PATCH] Switch training to the PPO agent and refactor its replay-buffer sampling

---
 reinforcement_learning/multi_agent_training.py |  2 +-
 reinforcement_learning/ppo_agent.py            | 11 +++++------
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index c127503..1092d84 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -172,7 +172,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, get_action_size(), train_params)
-    if False:
+    if True:
         policy = PPOAgent(state_size, get_action_size())
     if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
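
Note on the hunk above: DDDQNPolicy is still constructed first, but flipping the
hard-coded flag from "if False:" to "if True:" immediately overwrites it with a
PPOAgent, so training now runs with the PPO policy while the
DeadLockAvoidanceAgent branch stays disabled. Below is a minimal sketch of the
same selection driven by a training parameter instead of hard-coded switches;
the create_policy helper and the "policy" field on train_params are assumptions
(not in the patch), while DDDQNPolicy, PPOAgent, DeadLockAvoidanceAgent and
get_action_size are the names already used in multi_agent_training.py.

    # Hypothetical helper, not part of the patch: pick the policy by name
    # instead of editing if True / if False blocks in the source.
    def create_policy(policy_name, state_size, train_params, train_env):
        if policy_name == "dddqn":
            return DDDQNPolicy(state_size, get_action_size(), train_params)
        if policy_name == "ppo":
            return PPOAgent(state_size, get_action_size())
        if policy_name == "deadlock_avoidance":
            return DeadLockAvoidanceAgent(train_env, get_action_size())
        raise ValueError("Unknown policy: {}".format(policy_name))

    # Inside train_agent(...), the hard-coded block above would then reduce to:
    # policy = create_policy(getattr(train_params, "policy", "ppo"),
    #                        state_size, train_params, train_env)
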
diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py
index 16206e9..f80bea8 100644
--- a/reinforcement_learning/ppo_agent.py
+++ b/reinforcement_learning/ppo_agent.py
@@ -166,11 +166,6 @@ class PPOAgent(Policy):
             if self.use_replay_buffer:
                 self.memory.add(state_i, action_i, discounted_reward, state_next_i, done_i)
 
-        if self.use_replay_buffer:
-            if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
-                states, actions, rewards, next_states, dones, prob_actions = self.memory.sample()
-                return states, actions, rewards, next_states, dones, prob_actions
-
         # convert data to torch tensors
         states, actions, rewards, states_next, dones, prob_actions = \
             torch.tensor(state_list, dtype=torch.float).to(self.device), \
@@ -195,7 +190,11 @@ class PPOAgent(Policy):
                 states, actions, rewards, states_next, dones, probs_action = \
                     self._convert_transitions_to_torch_tensors(agent_episode_history)
                 # Optimize policy for K epochs:
-                for _ in range(int(self.K_epoch)):
+                for k_loop in range(int(self.K_epoch)):
+                    if self.use_replay_buffer and k_loop > 0:
+                        if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
+                            states, actions, rewards, states_next, dones, probs_action = self.memory.sample()
+
                     # Evaluating actions (actor) and values (critic)
                     logprobs, state_values, dist_entropy = self.actor_critic_model.evaluate(states, actions)
 
-- 
GitLab