diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py
index 54ba25fe1da709ca26587193561ac2e0a3b8b4a9..31d728ea5103cc5c4f5dc062fc780996f46e2d04 100644
--- a/reinforcement_learning/ppo/ppo_agent.py
+++ b/reinforcement_learning/ppo/ppo_agent.py
@@ -35,7 +35,8 @@ class DataBuffers:
         return self.memory.get(handle, [])
 
     def push_transition(self, handle, transition):
-        transitions = self.get_transitions(handle).append(transition)
+        transitions = self.get_transitions(handle)
+        transitions.append(transition)
         self.memory.update({handle: transitions})
 
@@ -103,7 +104,7 @@ class PPOAgent(Policy):
         return state, action, reward, s_next, done, prob_action
 
     def train_net(self):
-        for handle in range(self.n_agents):
+        for handle in range(len(self.memory)):
             agent_episode_history = self.memory.get_transitions(handle)
             if len(agent_episode_history) > 0:
                 # convert the replay buffer to torch tensors (arrays)
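
Note on the first hunk: list.append() mutates the list in place and returns None, so the old one-liner bound `transitions` to None and the follow-up `self.memory.update({handle: transitions})` wrote None back into the buffer, losing the agent's history. A minimal standalone sketch of the broken and fixed patterns (the handle value and transition payload here are hypothetical, not taken from the repo):

    # Broken pattern: append() returns None, so the new transition is lost
    # when the handle is not yet in the buffer.
    buffer = {}
    transitions = buffer.get(7, []).append("t0")
    assert transitions is None  # nothing usable to write back

    # Fixed pattern, as in the patch: keep the list reference, mutate it,
    # then write it back under the handle.
    transitions = buffer.get(7, [])
    transitions.append("t0")
    buffer.update({7: transitions})
    assert buffer[7] == ["t0"]

The second hunk makes train_net() loop over however many handles the replay buffer actually holds rather than a fixed self.n_agents; this presumes DataBuffers implements __len__ and that handles are contiguous integers starting at 0.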