From 016b9a58dfd58975931f06fc257c7e6f97afa064 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Thu, 3 Dec 2020 07:17:35 +0100
Subject: [PATCH] fix

---
 reinforcement_learning/ppo/ppo_agent.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py
index 54ba25f..31d728e 100644
--- a/reinforcement_learning/ppo/ppo_agent.py
+++ b/reinforcement_learning/ppo/ppo_agent.py
@@ -35,7 +35,8 @@ class DataBuffers:
         return self.memory.get(handle, [])
 
     def push_transition(self, handle, transition):
-        transitions = self.get_transitions(handle).append(transition)
+        transitions = self.get_transitions(handle)
+        transitions.append(transition)
         self.memory.update({handle: transitions})
 
 
@@ -103,7 +104,7 @@ class PPOAgent(Policy):
         return state, action, reward, s_next, done, prob_action
 
     def train_net(self):
-        for handle in range(self.n_agents):
+        for handle in range(len(self.memory)):
             agent_episode_history = self.memory.get_transitions(handle)
             if len(agent_episode_history) > 0:
                 # convert the replay buffer to torch tensors (arrays)
-- 
GitLab