Skip to content
Snippets Groups Projects
Commit 52a015e1 authored by Egli Adrian (IT-SCI-API-PFI)
Browse files

refactored and added new agent

parent 03748921
No related branches found
No related tags found
No related merge requests found
......@@ -172,7 +172,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
# Double Dueling DQN policy
policy = DDDQNPolicy(state_size, get_action_size(), train_params)
if False:
if True:
policy = PPOAgent(state_size, get_action_size())
if False:
policy = DeadLockAvoidanceAgent(train_env, get_action_size())
......
......@@ -166,11 +166,6 @@ class PPOAgent(Policy):
if self.use_replay_buffer:
self.memory.add(state_i, action_i, discounted_reward, state_next_i, done_i)
if self.use_replay_buffer:
if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
states, actions, rewards, next_states, dones, prob_actions = self.memory.sample()
return states, actions, rewards, next_states, dones, prob_actions
# convert data to torch tensors
states, actions, rewards, states_next, dones, prob_actions = \
torch.tensor(state_list, dtype=torch.float).to(self.device), \
......@@ -195,7 +190,11 @@ class PPOAgent(Policy):
states, actions, rewards, states_next, dones, probs_action = \
self._convert_transitions_to_torch_tensors(agent_episode_history)
# Optimize policy for K epochs:
for _ in range(int(self.K_epoch)):
for k_loop in range(int(self.K_epoch)):
if self.use_replay_buffer and k_loop > 0:
if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
states, actions, rewards, states_next, dones, probs_action = self.memory.sample()
# Evaluating actions (actor) and values (critic)
logprobs, state_values, dist_entropy = self.actor_critic_model.evaluate(states, actions)
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment