diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index c12750356efcccf7225bbe51e0c13fbfd44ac41b..1092d84e5e5fb2ef45424fb0ecd36823621d62a4 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -172,7 +172,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, get_action_size(), train_params)
-    if False:
+    if True:
         policy = PPOAgent(state_size, get_action_size())
     if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py
index 16206e9f6330e8cfe1b96ea2b797e4ea4f1ecdaf..f80bea86ebf89cd62f8e5310ddbbab88dd15acb1 100644
--- a/reinforcement_learning/ppo_agent.py
+++ b/reinforcement_learning/ppo_agent.py
@@ -166,11 +166,6 @@ class PPOAgent(Policy):
             if self.use_replay_buffer:
                 self.memory.add(state_i, action_i, discounted_reward, state_next_i, done_i)
 
-        if self.use_replay_buffer:
-            if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
-                states, actions, rewards, next_states, dones, prob_actions = self.memory.sample()
-                return states, actions, rewards, next_states, dones, prob_actions
-
         # convert data to torch tensors
         states, actions, rewards, states_next, dones, prob_actions = \
             torch.tensor(state_list, dtype=torch.float).to(self.device), \
@@ -195,7 +190,11 @@ class PPOAgent(Policy):
         states, actions, rewards, states_next, dones, probs_action = \
             self._convert_transitions_to_torch_tensors(agent_episode_history)
         # Optimize policy for K epochs:
-        for _ in range(int(self.K_epoch)):
+        for k_loop in range(int(self.K_epoch)):
+            if self.use_replay_buffer and k_loop > 0:
+                if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
+                    states, actions, rewards, states_next, dones, probs_action = self.memory.sample()
+
             # Evaluating actions (actor) and values (critic)
             logprobs, state_values, dist_entropy = self.actor_critic_model.evaluate(states, actions)
 
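
Context for the hunks above: the multi_agent_training.py change just flips the policy selection from DDDQNPolicy to PPOAgent, while the ppo_agent.py changes remove the early replay-buffer return that sat before the tensor conversion and instead resample inside the K-epoch loop. The first epoch still trains on the episode's freshly collected transitions, and each later epoch (k_loop > 0) draws a new minibatch from the buffer once it holds more than buffer_min_size and batch_size entries. Below is a minimal standalone sketch of that update pattern, using hypothetical stand-ins (SimpleReplayBuffer, batch_from, update_fn) rather than the repository's ReplayBuffer and actor-critic classes:

# Standalone sketch of the modified K-epoch update loop; not the
# repository's implementation. Names below are illustrative only.
import random
from collections import deque


class SimpleReplayBuffer:
    """Toy replay buffer storing (state, action, reward, next_state, done, prob) tuples."""

    def __init__(self, capacity, batch_size):
        self.memory = deque(maxlen=capacity)
        self.batch_size = batch_size

    def add(self, transition):
        self.memory.append(transition)

    def sample(self):
        # Random minibatch, unzipped into parallel lists.
        batch = random.sample(self.memory, self.batch_size)
        return tuple(map(list, zip(*batch)))

    def __len__(self):
        return len(self.memory)


def batch_from(transitions):
    # Unzip a list of transitions into parallel lists
    # (stand-in for _convert_transitions_to_torch_tensors).
    return tuple(map(list, zip(*transitions)))


def optimize_for_k_epochs(episode_transitions, buffer, k_epochs,
                          buffer_min_size, update_fn):
    """Mirror of the patched loop: epoch 0 reuses the on-policy episode data,
    epochs with k_loop > 0 resample from the replay buffer once it is warm."""
    states, actions, rewards, next_states, dones, probs = batch_from(episode_transitions)
    for k_loop in range(k_epochs):
        if k_loop > 0 and len(buffer) > buffer_min_size and len(buffer) > buffer.batch_size:
            states, actions, rewards, next_states, dones, probs = buffer.sample()
        # Stand-in for the actor-critic evaluate / loss / optimizer step.
        update_fn(states, actions, rewards, next_states, dones, probs)

One consequence of this layout is that epochs after the first optimise against transitions and stored action probabilities produced by earlier policy versions, so the PPO ratio is computed on partially off-policy data; only the k_loop == 0 pass is guaranteed to use on-policy samples.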