From 52a015e1a8e85795926f54ee034a6318c6deb581 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Wed, 16 Dec 2020 20:13:41 +0100
Subject: [PATCH] refactored and added new agent

---
 reinforcement_learning/multi_agent_training.py |  2 +-
 reinforcement_learning/ppo_agent.py            | 11 +++++------
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index c127503..1092d84 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -172,7 +172,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, get_action_size(), train_params)
-    if False:
+    if True:
         policy = PPOAgent(state_size, get_action_size())
     if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py
index 16206e9..f80bea8 100644
--- a/reinforcement_learning/ppo_agent.py
+++ b/reinforcement_learning/ppo_agent.py
@@ -166,11 +166,6 @@ class PPOAgent(Policy):
             if self.use_replay_buffer:
                 self.memory.add(state_i, action_i, discounted_reward, state_next_i, done_i)
 
-        if self.use_replay_buffer:
-            if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
-                states, actions, rewards, next_states, dones, prob_actions = self.memory.sample()
-                return states, actions, rewards, next_states, dones, prob_actions
-
         # convert data to torch tensors
         states, actions, rewards, states_next, dones, prob_actions = \
             torch.tensor(state_list, dtype=torch.float).to(self.device), \
@@ -195,7 +190,11 @@ class PPOAgent(Policy):
         states, actions, rewards, states_next, dones, probs_action = \
             self._convert_transitions_to_torch_tensors(agent_episode_history)
 
         # Optimize policy for K epochs:
-        for _ in range(int(self.K_epoch)):
+        for k_loop in range(int(self.K_epoch)):
+            if self.use_replay_buffer and k_loop > 0:
+                if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
+                    states, actions, rewards, states_next, dones, probs_action = self.memory.sample()
+
             # Evaluating actions (actor) and values (critic)
             logprobs, state_values, dist_entropy = self.actor_critic_model.evaluate(states, actions)
-- 
GitLab
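
For context (not part of the patch itself): the change to train_net means only the first of the K optimization passes trains on the freshly converted episode transitions; every later pass (k_loop > 0) draws a new random batch from the replay buffer, provided the buffer already holds more than buffer_min_size and batch_size entries. Below is a minimal, self-contained sketch of that sampling pattern; SimpleReplayBuffer and k_epoch_update are hypothetical stand-ins, not the repo's actual memory class or PPOAgent.train_net.

import random
from collections import deque

import numpy as np


class SimpleReplayBuffer:
    """Toy fixed-size buffer of (state, action, reward, next_state, done) tuples."""

    def __init__(self, capacity=10_000):
        self.storage = deque(maxlen=capacity)

    def __len__(self):
        return len(self.storage)

    def add(self, transition):
        self.storage.append(transition)

    def sample(self, batch_size):
        batch = random.sample(list(self.storage), batch_size)
        # Transpose the list of tuples into (states, actions, rewards, next_states, dones).
        return tuple(np.array(column) for column in zip(*batch))


def k_epoch_update(episode_batch, memory, k_epochs=10, batch_size=128, buffer_min_size=128):
    """Run K optimization passes; after the first pass, resample from the buffer."""
    batch = episode_batch
    for k_loop in range(k_epochs):
        # Mirrors the patch: passes after the first one draw a fresh random batch,
        # and only once the buffer holds enough transitions.
        if k_loop > 0 and len(memory) > buffer_min_size and len(memory) > batch_size:
            batch = memory.sample(batch_size)
        states, actions, rewards, next_states, dones = batch
        # Placeholder for the actual PPO actor-critic update on this batch.
        print(f"epoch {k_loop}: updating on {len(states)} transitions")


if __name__ == "__main__":
    memory = SimpleReplayBuffer()
    rng = np.random.default_rng(0)
    # Fill the buffer with dummy transitions (state dimension 4, 5 discrete actions).
    for _ in range(500):
        memory.add((rng.normal(size=4), int(rng.integers(5)), float(rng.normal()),
                    rng.normal(size=4), False))
    # Treat the 64 most recent transitions as the "episode batch" for the first pass.
    episode_batch = tuple(np.array(col) for col in zip(*list(memory.storage)[-64:]))
    k_epoch_update(episode_batch, memory)

The same pattern also explains why the standalone replay-buffer block was removed from _convert_transitions_to_torch_tensors: sampling now happens inside the epoch loop rather than while the transitions are being converted.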