Commit 8d6304b3 authored by Egli Adrian (IT-SCI-API-PFI)

. doc

parent 41d4b483
@@ -165,38 +165,45 @@ class PPOAgent(Policy):
         return states, actions, rewards, states_next, dones, prob_actions
 
     def train_net(self):
-        for handle in range(len(self.memory)):
-            agent_episode_history = self.memory.get_transitions(handle)
-            if len(agent_episode_history) > 0:
-                # convert the replay buffer to torch tensors (arrays)
-                states, actions, rewards, states_next, dones, probs_action = \
-                    self._convert_transitions_to_torch_tensors(agent_episode_history)
-
-                # Optimize policy for K epochs:
-                for _ in range(self.K_epoch):
-                    # evaluating actions (actor) and values (critic)
+        # Optimize policy for K epochs:
+        for _ in range(self.K_epoch):
+            # All agents have to propagate the experience they collected during the past episode
+            for handle in range(len(self.memory)):
+                # Extract the agent's episode history (list of all transitions)
+                agent_episode_history = self.memory.get_transitions(handle)
+                if len(agent_episode_history) > 0:
+                    # Convert the replay buffer to torch tensors (arrays)
+                    states, actions, rewards, states_next, dones, probs_action = \
+                        self._convert_transitions_to_torch_tensors(agent_episode_history)
+
+                    # Evaluate actions (actor) and values (critic)
                     logprobs, state_values, dist_entropy = self.actor_critic_model.evaluate(states, actions)
 
-                    # finding the ratios (pi_thetas / pi_thetas_replayed):
+                    # Find the ratios (pi_theta / pi_theta_replayed):
                     ratios = torch.exp(logprobs - probs_action.detach())
 
-                    # finding Surrogate Loss:
+                    # Find the surrogate loss:
                     advantages = rewards - state_values.detach()
                     surr1 = ratios * advantages
                     surr2 = torch.clamp(ratios, 1. - self.surrogate_eps_clip, 1. + self.surrogate_eps_clip) * advantages
+
+                    # The loss estimates the policy gradient; the entropy term penalizes the loss when the
+                    # policy becomes deterministic, since the gradient would then become very flat and no
+                    # longer be useful for learning.
                     loss = \
                         -torch.min(surr1, surr2) \
                         + self.weight_loss * self.loss_function(state_values, rewards) \
                         - self.weight_entropy * dist_entropy
 
-                    # make a gradient step
+                    # Make a gradient step
                     self.optimizer.zero_grad()
                     loss.mean().backward()
                     self.optimizer.step()
 
-                    # store current loss to the agent
+                    # Keep the current loss on the agent (for debugging purposes only)
                     self.loss = loss.mean().detach().numpy()
+
+        # Reset all collected transition data
         self.memory.reset()
 
     def end_episode(self, train):
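The structural change above moves the K_epoch loop to the outside: each optimization epoch now revisits the transitions of all agents, instead of running all K epochs on one agent's episode before moving to the next. The loss computation itself is unchanged. For reference, the clipped surrogate objective computed in the added lines can be isolated as in the minimal sketch below, on dummy tensors. The hyper-parameter values and the MSE value loss are illustrative assumptions, stand-ins for self.surrogate_eps_clip, self.weight_loss, self.weight_entropy and self.loss_function, not values taken from this commit.

import torch

torch.manual_seed(0)

eps_clip = 0.1         # stand-in for self.surrogate_eps_clip (assumed value)
weight_loss = 0.5      # stand-in for self.weight_loss, the critic-loss weight (assumed value)
weight_entropy = 0.01  # stand-in for self.weight_entropy, the entropy-bonus weight (assumed value)

# Dummy batch of 8 transitions: current and recorded action log-probabilities,
# critic estimates, observed returns and the entropy of the action distribution.
logprobs = torch.randn(8, requires_grad=True)
probs_action = torch.randn(8)                     # log-probs recorded when acting
state_values = torch.randn(8, requires_grad=True)
rewards = torch.randn(8)                          # returns, as in the agent's code
dist_entropy = torch.full((8,), 0.7)

# Probability ratio pi_theta / pi_theta_old, computed on the log scale
ratios = torch.exp(logprobs - probs_action.detach())

# Clipped surrogate objective (PPO)
advantages = rewards - state_values.detach()
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1. - eps_clip, 1. + eps_clip) * advantages

# Actor loss + weighted critic loss - weighted entropy bonus (elementwise, as in the diff)
loss = (-torch.min(surr1, surr2)
        + weight_loss * torch.nn.functional.mse_loss(state_values, rewards, reduction='none')
        - weight_entropy * dist_entropy)

loss.mean().backward()      # gradients flow into logprobs and state_values
print(loss.mean().item())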