Commit 729722e3 authored by Egli Adrian (IT-SCI-API-PFI)

very slow convergence, but converges :-)

parent 05f62176
@@ -336,7 +336,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             elif agent.status == RailAgentStatus.READY_TO_DEPART:
                 all_rewards[agent_handle] -= 5.0
             else:
-                if True:
+                if False:
                     agent_positions = get_agent_positions(train_env)
                     for agent_handle in train_env.get_agent_handles():
                         agent = train_env.agents[agent_handle]
@@ -565,7 +565,7 @@ if __name__ == "__main__":
     parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
                         type=int)
     parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
-    parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=200, type=int)
+    parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=2000, type=int)
     parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.05, type=float)
     parser.add_argument("--eps_decay", help="exploration decay", default=0.9975, type=float)
 import copy
 import os
+import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.optim as optim
 from torch.distributions import Categorical

 # Hyperparameters
 from reinforcement_learning.policy import Policy

 device = torch.device("cpu")  # "cuda:0" if torch.cuda.is_available() else "cpu")
 print("device:", device)
@@ -43,7 +43,8 @@ class ActorCriticModel(nn.Module):
             nn.Tanh(),
             nn.Linear(hidsize1, hidsize2),
             nn.Tanh(),
-            nn.Linear(hidsize2, action_size)
+            nn.Linear(hidsize2, action_size),
+            nn.Softmax(dim=-1)
         )

         self.critic = nn.Sequential(
@@ -57,13 +58,13 @@ class ActorCriticModel(nn.Module):
     def forward(self, x):
         raise NotImplementedError

-    def act_prob(self, states, softmax_dim=0):
-        x = self.actor(states)
-        prob = F.softmax(x, dim=softmax_dim)
-        return prob
+    def get_actor_dist(self, state):
+        action_probs = self.actor(state)
+        dist = Categorical(action_probs)
+        return dist

     def evaluate(self, states, actions):
-        action_probs = self.act_prob(states)
+        action_probs = self.actor(states)
         dist = Categorical(action_probs)
         action_logprobs = dist.log_prob(actions)
         dist_entropy = dist.entropy()
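Taken together, the two hunks above move the softmax into the actor head and let the model hand back a Categorical distribution instead of raw logits pushed through F.softmax. A minimal sketch of the resulting module, assuming hidden sizes of 128 and a critic that mirrors the actor with a single output unit (neither detail is shown in this diff):

import torch
import torch.nn as nn
from torch.distributions import Categorical

class ActorCriticSketch(nn.Module):
    # Illustrative reconstruction only; hidden sizes and naming are assumptions.
    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128):
        super().__init__()
        # Actor ends in Softmax, so it directly outputs action probabilities.
        self.actor = nn.Sequential(
            nn.Linear(state_size, hidsize1),
            nn.Tanh(),
            nn.Linear(hidsize1, hidsize2),
            nn.Tanh(),
            nn.Linear(hidsize2, action_size),
            nn.Softmax(dim=-1),
        )
        # Critic estimates the scalar state value.
        self.critic = nn.Sequential(
            nn.Linear(state_size, hidsize1),
            nn.Tanh(),
            nn.Linear(hidsize1, hidsize2),
            nn.Tanh(),
            nn.Linear(hidsize2, 1),
        )

    def get_actor_dist(self, state):
        return Categorical(self.actor(state))

    def evaluate(self, states, actions):
        dist = Categorical(self.actor(states))
        return dist.log_prob(actions), torch.squeeze(self.critic(states)), dist.entropy()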
@@ -95,11 +96,11 @@ class PPOAgent(Policy):
         super(PPOAgent, self).__init__()

         # parameters
-        self.learning_rate = 0.1e-3
-        self.gamma = 0.98
-        self.surrogate_eps_clip = 0.1
+        self.learning_rate = 0.1e-4
+        self.gamma = 0.99
+        self.surrogate_eps_clip = 0.2
         self.K_epoch = 3
-        self.weight_loss = 0.9
+        self.weight_loss = 0.5
         self.weight_entropy = 0.01

         # objects
@@ -107,20 +108,26 @@ class PPOAgent(Policy):
         self.loss = 0
         self.actor_critic_model = ActorCriticModel(state_size, action_size)
         self.optimizer = optim.Adam(self.actor_critic_model.parameters(), lr=self.learning_rate)
-        self.lossFunction = nn.MSELoss()
+        self.loss_function = nn.MSELoss()

     def reset(self):
         pass

     def act(self, state, eps=None):
         # sample a action to take
-        prob = self.actor_critic_model.act_prob(torch.from_numpy(state).float())
-        return Categorical(prob).sample().item()
+        torch_state = torch.tensor(state, dtype=torch.float).to(device)
+        dist = self.actor_critic_model.get_actor_dist(torch_state)
+        action = dist.sample()
+        return action.item()

     def step(self, handle, state, action, reward, next_state, done):
-        # record transitions ([state] -> [action] -> [reward, nextstate, done])
-        prob = self.actor_critic_model.act_prob(torch.from_numpy(state).float())
-        transition = (state, action, reward, next_state, prob[action].item(), done)
+        # record transitions ([state] -> [action] -> [reward, next_state, done])
+        torch_action = torch.tensor(action, dtype=torch.float).to(device)
+        torch_state = torch.tensor(state, dtype=torch.float).to(device)
+        # evaluate actor
+        dist = self.actor_critic_model.get_actor_dist(torch_state)
+        action_logprobs = dist.log_prob(torch_action)
+        transition = (state, action, reward, next_state, action_logprobs.item(), done)
         self.memory.push_transition(handle, transition)

     def _convert_transitions_to_torch_tensors(self, transitions_array):
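Storing action_logprobs.item() with each transition matters because PPO later needs the log-probability the old policy assigned to the action it actually took: the importance ratio is exp(new_log_prob - old_log_prob). A small sketch of that ratio computation, assuming batched tensors replayed from the recorded transitions; the function and variable names below are illustrative:

import torch

def importance_ratios(model, states, actions, old_log_probs):
    # Re-evaluate the stored state/action pairs under the current policy,
    # then compare against the log-probs recorded at collection time.
    new_log_probs, state_values, dist_entropy = model.evaluate(states, actions)
    ratios = torch.exp(new_log_probs - old_log_probs.detach())
    return ratios, state_values, dist_entropy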
@@ -177,10 +184,10 @@ class PPOAgent(Policy):
             # finding Surrogate Loss:
             advantages = rewards - state_values.detach()
             surr1 = ratios * advantages
-            surr2 = torch.clamp(ratios, 1 - self.surrogate_eps_clip, 1 + self.surrogate_eps_clip) * advantages
+            surr2 = torch.clamp(ratios, 1. - self.surrogate_eps_clip, 1. + self.surrogate_eps_clip) * advantages
             loss = \
                 -torch.min(surr1, surr2) \
-                + self.weight_loss * self.lossFunction(state_values, rewards) \
+                + self.weight_loss * self.loss_function(state_values, rewards) \
                 - self.weight_entropy * dist_entropy

             # make a gradient step
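For reference, the block above is the standard PPO clipped surrogate objective plus a weighted value loss and an entropy bonus. A self-contained sketch of the same computation, using the new default weights from this commit; only the formula mirrors the diff, the function wrapper is illustrative:

import torch
import torch.nn as nn

def ppo_loss(ratios, advantages, state_values, rewards, dist_entropy,
             eps_clip=0.2, weight_loss=0.5, weight_entropy=0.01):
    # Pessimistic clipped surrogate: take the elementwise minimum of the
    # unclipped and clipped policy terms, then add value loss and entropy bonus.
    surr1 = ratios * advantages
    surr2 = torch.clamp(ratios, 1. - eps_clip, 1. + eps_clip) * advantages
    value_loss = nn.MSELoss()(state_values, rewards)
    loss = -torch.min(surr1, surr2) + weight_loss * value_loss - weight_entropy * dist_entropy
    return loss.mean()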