Commit 70393c79 authored by Egli Adrian (IT-SCI-API-PFI)

initial ppo version actor critic

parent 769f25ec
@@ -314,7 +314,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             next_obs, all_rewards, done, info = train_env.step(action_dict)

             # Reward shaping .Dead-lock .NotMoving .NotStarted
-            if True:
+            if False:
                 agent_positions = get_agent_positions(train_env)
                 for agent_handle in train_env.get_agent_handles():
                     agent = train_env.agents[agent_handle]
@@ -335,6 +335,17 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                                 all_rewards[agent_handle] -= 5.0
                             elif agent.status == RailAgentStatus.READY_TO_DEPART:
                                 all_rewards[agent_handle] -= 5.0
+            else:
+                if True:
+                    agent_positions = get_agent_positions(train_env)
+                    for agent_handle in train_env.get_agent_handles():
+                        agent = train_env.agents[agent_handle]
+                        act = action_dict.get(agent_handle, RailEnvActions.MOVE_FORWARD)
+                        if agent.status == RailAgentStatus.ACTIVE:
+                            if done[agent_handle] == False:
+                                if check_for_dealock(agent_handle, train_env, agent_positions):
+                                    all_rewards[agent_handle] -= 1000.0
+                                    done[agent_handle] = True

             step_timer.end()
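For readability, the dead-lock penalty added in this hunk can be read as a small standalone helper. This is an illustrative sketch only, not code from the commit: the helper name `apply_deadlock_penalty` is made up here, while `get_agent_positions` and `check_for_dealock` are the repository's own utilities, called exactly as in the diff.

from flatland.envs.agent_utils import RailAgentStatus  # Flatland 2.x

# get_agent_positions and check_for_dealock are helpers from this repository,
# used with the same signatures as in the hunk above.
def apply_deadlock_penalty(train_env, all_rewards, done, penalty=1000.0):
    agent_positions = get_agent_positions(train_env)
    for agent_handle in train_env.get_agent_handles():
        agent = train_env.agents[agent_handle]
        if agent.status == RailAgentStatus.ACTIVE and not done[agent_handle]:
            if check_for_dealock(agent_handle, train_env, agent_positions):
                # a dead-locked agent gets a large negative reward and is
                # treated as finished for the rest of the episode
                all_rewards[agent_handle] -= penalty
                done[agent_handle] = True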
@@ -548,13 +559,13 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):

 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5000, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=25000, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2,
                         type=int)
     parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
                         type=int)
     parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
-    parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
+    parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=200, type=int)
     parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.05, type=float)
     parser.add_argument("--eps_decay", help="exploration decay", default=0.9975, type=float)
...
@@ -15,8 +15,10 @@ GAMMA = 0.98
 LAMBDA = 0.9
 SURROGATE_EPS_CLIP = 0.01
 K_EPOCH = 3
+WEIGHT_LOSS = 0.5
+WEIGHT_ENTROPY = 0.01

-device = torch.device("cpu")#"cuda:0" if torch.cuda.is_available() else "cpu")
+device = torch.device("cpu")  # "cuda:0" if torch.cuda.is_available() else "cpu")
 print("device:", device)
@@ -40,85 +42,46 @@ class DataBuffers:
         self.memory.update({handle: transitions})


-class GlobalModel(nn.Module):
-    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128):
-        super(GlobalModel, self).__init__()
-        self._layer_1 = nn.Linear(state_size, hidsize1)
-        self.global_network = nn.Linear(hidsize1, hidsize2)
-
-    def get_model(self):
-        return self.global_network
+class ActorCriticModel(nn.Module):
+
+    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128):
+        super(ActorCriticModel, self).__init__()
+        self.actor = nn.Sequential(
+            nn.Linear(state_size, hidsize1),
+            nn.Tanh(),
+            nn.Linear(hidsize1, hidsize2),
+            nn.Tanh(),
+            nn.Linear(hidsize2, action_size)
+        )
+        self.critic = nn.Sequential(
+            nn.Linear(state_size, hidsize1),
+            nn.Tanh(),
+            nn.Linear(hidsize1, hidsize2),
+            nn.Tanh(),
+            nn.Linear(hidsize2, 1)
+        )

     def forward(self, x):
-        val = F.relu(self._layer_1(x))
-        val = F.relu(self.global_network(val))
-        return val
-
-    def save(self, filename):
-        # print("Saving model from checkpoint:", filename)
-        torch.save(self.global_network.state_dict(), filename + ".global")
-
-    def _load(self, obj, filename):
-        if os.path.exists(filename):
-            print(' >> ', filename)
-            try:
-                obj.load_state_dict(torch.load(filename, map_location=device))
-            except:
-                print(" >> failed!")
-        return obj
-
-    def load(self, filename):
-        self.global_network = self._load(self.global_network, filename + ".global")
-
-
-class PolicyNetwork(nn.Module):
-    def __init__(self, state_size, action_size, global_network, hidsize1=128, hidsize2=128):
-        super(PolicyNetwork, self).__init__()
-        self.global_network = global_network
-        self.policy_network = nn.Linear(hidsize2, action_size)
-
-    def forward(self, x, softmax_dim=0):
-        x = F.tanh(self.global_network.forward(x))
-        x = self.policy_network(x)
+        raise NotImplementedError
+
+    def act_prob(self, states, softmax_dim=0):
+        x = self.actor(states)
         prob = F.softmax(x, dim=softmax_dim)
         return prob

-    # Checkpointing methods
-    def save(self, filename):
-        # print("Saving model from checkpoint:", filename)
-        torch.save(self.policy_network.state_dict(), filename + ".policy")
-
-    def _load(self, obj, filename):
-        if os.path.exists(filename):
-            print(' >> ', filename)
-            try:
-                obj.load_state_dict(torch.load(filename, map_location=device))
-            except:
-                print(" >> failed!")
-        return obj
-
-    def load(self, filename):
-        print("load policy from file", filename)
-        self.policy_network = self._load(self.policy_network, filename + ".policy")
-
-
-class ValueNetwork(nn.Module):
-    def __init__(self, state_size, action_size, global_network, hidsize1=128, hidsize2=128):
-        super(ValueNetwork, self).__init__()
-        self.global_network = global_network
-        self.value_network = nn.Linear(hidsize2, 1)
-
-    def forward(self, x):
-        x = F.tanh(self.global_network.forward(x))
-        v = self.value_network(x)
-        return v
+    def evaluate(self, states, actions):
+        action_probs = self.act_prob(states)
+        dist = Categorical(action_probs)
+        action_logprobs = dist.log_prob(actions)
+        dist_entropy = dist.entropy()
+        state_value = self.critic(states)
+        return action_logprobs, torch.squeeze(state_value), dist_entropy

-    # Checkpointing methods
     def save(self, filename):
         # print("Saving model from checkpoint:", filename)
-        torch.save(self.value_network.state_dict(), filename + ".value")
+        torch.save(self.actor.state_dict(), filename + ".actor")
+        torch.save(self.critic.state_dict(), filename + ".value")

     def _load(self, obj, filename):
         if os.path.exists(filename):
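As a quick illustration of the refactored model above (not part of the commit; the state and action sizes below are arbitrary example values), `act_prob` returns action probabilities from the actor head and `evaluate` returns log-probabilities, state values and entropies for a batch, which is how `PPOAgent` uses them further down in this diff:

import torch
from torch.distributions.categorical import Categorical

# assumes the ActorCriticModel defined above is in scope; 231 / 5 are example sizes
model = ActorCriticModel(state_size=231, action_size=5)

state = torch.rand(231)                    # a single observation vector
prob = model.act_prob(state)               # action probabilities from the actor head
action = Categorical(prob).sample()        # sample an action, as PPOAgent.act() does

states = torch.rand(8, 231)                # a small batch of observations
actions = torch.randint(0, 5, (8,))
logprobs, values, entropy = model.evaluate(states, actions)
print(logprobs.shape, values.shape, entropy.shape)  # torch.Size([8]) for all three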
@@ -130,69 +93,68 @@ class ValueNetwork(nn.Module):
         return obj

     def load(self, filename):
-        self.value_network = self._load(self.value_network, filename + ".value")
+        print("load policy from file", filename)
+        self.actor = self._load(self.actor, filename + ".actor")
+        self.critic = self._load(self.critic, filename + ".critic")


 class PPOAgent(Policy):
     def __init__(self, state_size, action_size):
         super(PPOAgent, self).__init__()
-        # create the data buffer - collects all transitions (state, action, reward, next_state, action_prob, done)
-        # each agent owns its own buffer
         self.memory = DataBuffers()
-        # signal - stores the current loss
         self.loss = 0
-        # create the global, shared deep neuronal network
-        self.global_network = GlobalModel(state_size, action_size)
-        # create the "critic" or value network
-        self.value_network = ValueNetwork(state_size, action_size, self.global_network)
-        # create the "actor" or policy network
-        self.policy_network = PolicyNetwork(state_size, action_size, self.global_network)
-        # create for each network a optimizer
-        self.value_optimizer = optim.Adam(self.value_network.parameters(), lr=LEARNING_RATE)
-        self.policy_optimizer = optim.Adam(self.policy_network.parameters(), lr=LEARNING_RATE)
+        self.actor_critic_model = ActorCriticModel(state_size, action_size)
+        self.optimizer = optim.Adam(self.actor_critic_model.parameters(), lr=LEARNING_RATE)
+        self.lossFunction = nn.MSELoss()

     def reset(self):
         pass

     def act(self, state, eps=None):
-        prob = self.policy_network(torch.from_numpy(state).float())
-        m = Categorical(prob)
-        a = m.sample().item()
-        return a
+        # sample a action to take
+        prob = self.actor_critic_model.act_prob(torch.from_numpy(state).float())
+        return Categorical(prob).sample().item()

     def step(self, handle, state, action, reward, next_state, done):
-        # Record the results of the agent's action as transition
-        prob = self.policy_network(torch.from_numpy(state).float())
+        # record transitions ([state] -> [action] -> [reward, nextstate, done])
+        prob = self.actor_critic_model.act_prob(torch.from_numpy(state).float())
         transition = (state, action, reward, next_state, prob[action].item(), done)
         self.memory.push_transition(handle, transition)

     def _convert_transitions_to_torch_tensors(self, transitions_array):
-        # build empty lists(arrays)
         state_list, action_list, reward_list, state_next_list, prob_a_list, done_list = [], [], [], [], [], []
-        total_reward = 0
-        for transition in transitions_array:
+
+        # set discounted_reward to zero
+        discounted_reward = 0
+        for transition in transitions_array[::-1]:
             state_i, action_i, reward_i, state_next_i, prob_action_i, done_i = transition
-            state_list.append(state_i)
-            action_list.append([action_i])
-            total_reward += reward_i
+            state_list.insert(0, state_i)
+            action_list.insert(0, action_i)
             if done_i:
-                reward_list.append([max(1.0, total_reward)])
-                done_list.append([1])
+                discounted_reward = 0
+                done_list.insert(0, 1)
             else:
-                reward_list.append([0])
-                done_list.append([0])
-            state_next_list.append(state_next_i)
-            prob_a_list.append([prob_action_i])
+                discounted_reward = reward_i + GAMMA * discounted_reward
+                done_list.insert(0, 0)
+            reward_list.insert(0, discounted_reward)
+            state_next_list.insert(0, state_next_i)
+            prob_a_list.insert(0, prob_action_i)

-        state, action, reward, s_next, done, prob_action = torch.tensor(state_list, dtype=torch.float), \
-                                                           torch.tensor(action_list), \
-                                                           torch.tensor(reward_list), \
-                                                           torch.tensor(state_next_list, dtype=torch.float), \
-                                                           torch.tensor(done_list, dtype=torch.float), \
-                                                           torch.tensor(prob_a_list)
-
-        return state, action, reward, s_next, done, prob_action
+        # convert data to torch tensors
+        states, actions, rewards, states_next, dones, prob_actions = \
+            torch.tensor(state_list, dtype=torch.float).to(device), \
+            torch.tensor(action_list).to(device), \
+            torch.tensor(reward_list, dtype=torch.float).to(device), \
+            torch.tensor(state_next_list, dtype=torch.float).to(device), \
+            torch.tensor(done_list, dtype=torch.float).to(device), \
+            torch.tensor(prob_a_list).to(device)
+
+        # standard-normalize rewards
+        rewards = (rewards - rewards.mean()) / (rewards.std() + 1.e-5)
+
+        return states, actions, rewards, states_next, dones, prob_actions

     def train_net(self):
         for handle in range(len(self.memory)):
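The rewritten `_convert_transitions_to_torch_tensors` above walks the episode backwards and replaces the old per-episode total-reward bookkeeping with per-step discounted returns, which are then standard-normalized. A small worked example of the backward pass (arbitrary rewards, GAMMA = 0.98; not from the commit):

GAMMA = 0.98

# rewards of a three-step episode; the last step is terminal
rewards = [0.0, 1.0, 2.0]
dones = [False, False, True]

returns = []
discounted_reward = 0.0
for r, d in zip(rewards[::-1], dones[::-1]):  # walk backwards, as in the diff
    discounted_reward = 0.0 if d else r + GAMMA * discounted_reward
    returns.insert(0, discounted_reward)

print(returns)  # [0.98, 1.0, 0.0] -- note the terminal step's own reward is dropped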
@@ -202,48 +164,29 @@ class PPOAgent(Policy):
             states, actions, rewards, states_next, dones, probs_action = \
                 self._convert_transitions_to_torch_tensors(agent_episode_history)

-            # run K_EPOCH optimisation steps
-            for i in range(K_EPOCH):
-                # temporal difference function / and prepare advantage function data
-                estimated_target_value = \
-                    rewards + GAMMA * self.value_network(states_next) * (1.0 - dones)
-                difference_to_expected_value_deltas = \
-                    estimated_target_value - self.value_network(states)
-                difference_to_expected_value_deltas = difference_to_expected_value_deltas.detach().numpy()
-
-                # build advantage function and convert it to torch tensor (array)
-                advantage_list = []
-                advantage_value = 0.0
-                for difference_to_expected_value_t in difference_to_expected_value_deltas[::-1]:
-                    advantage_value = LAMBDA * advantage_value + difference_to_expected_value_t[0]
-                    advantage_list.append([advantage_value])
-                advantage_list.reverse()
-                advantages = torch.tensor(advantage_list, dtype=torch.float)
-
-                # estimate pi_action for all state
-                pi_actions = self.policy_network.forward(states, softmax_dim=1).gather(1, actions)
-                # calculate the ratios
-                ratios = torch.exp(torch.log(pi_actions) - torch.log(probs_action))
-                # Normal Policy Gradient objective
-                surrogate_objective = ratios * advantages
-                # clipped version of Normal Policy Gradient objective
-                clipped_surrogate_objective = torch.clamp(ratios * advantages,
-                                                          1 - SURROGATE_EPS_CLIP,
-                                                          1 + SURROGATE_EPS_CLIP)
-                # create value loss function
-                value_loss = F.smooth_l1_loss(self.value_network(states),
-                                              estimated_target_value.detach())
-                # create final loss function
-                loss = -torch.min(surrogate_objective, clipped_surrogate_objective) + value_loss
-
-                # update policy ("actor") and value ("critic") networks
-                self.value_optimizer.zero_grad()
-                self.policy_optimizer.zero_grad()
+            # Optimize policy for K epochs:
+            for _ in range(K_EPOCH):
+                # evaluating actions (actor) and values (critic)
+                logprobs, state_values, dist_entropy = self.actor_critic_model.evaluate(states, actions)
+
+                # finding the ratios (pi_thetas / pi_thetas_replayed):
+                ratios = torch.exp(logprobs - probs_action.detach())
+
+                # finding Surrogate Loss:
+                advantages = rewards - state_values.detach()
+                surr1 = ratios * advantages
+                surr2 = torch.clamp(ratios, 1 - SURROGATE_EPS_CLIP, 1 + SURROGATE_EPS_CLIP) * advantages
+                loss = \
+                    -torch.min(surr1, surr2) \
+                    + WEIGHT_LOSS * self.lossFunction(state_values, rewards) \
+                    - WEIGHT_ENTROPY * dist_entropy
+
+                # make a gradient step
+                self.optimizer.zero_grad()
                 loss.mean().backward()
-                self.value_optimizer.step()
-                self.policy_optimizer.step()
+                self.optimizer.step()

-                # store current loss
+                # store current loss to the agent
                 self.loss = loss.mean().detach().numpy()

         self.memory.reset()
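The new `train_net` is the standard PPO update: a clipped surrogate objective plus a value loss weighted by WEIGHT_LOSS and an entropy bonus weighted by WEIGHT_ENTROPY. The following self-contained sketch spells out the same loss on dummy tensors; it is not from the commit, and it forms the ratio from two log-probabilities (the committed code stores `probs_action` as plain probabilities recorded at rollout time):

import torch

SURROGATE_EPS_CLIP = 0.01
WEIGHT_LOSS = 0.5
WEIGHT_ENTROPY = 0.01

# toy batch: log-probs under the current policy, log-probs recorded at rollout
# time, normalized returns and current value estimates (all shapes: [batch])
logprobs = torch.tensor([-1.60, -1.45, -1.70])
old_logprobs = torch.tensor([-1.61, -1.50, -1.65])
returns = torch.tensor([1.10, -0.30, -0.80])
state_values = torch.tensor([0.90, -0.10, -0.60])
dist_entropy = torch.tensor([1.55, 1.52, 1.58])

ratios = torch.exp(logprobs - old_logprobs)      # pi_theta / pi_theta_old
advantages = returns - state_values.detach()     # simple advantage estimate
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1 - SURROGATE_EPS_CLIP, 1 + SURROGATE_EPS_CLIP) * advantages

value_loss = torch.nn.functional.mse_loss(state_values, returns, reduction="none")
loss = -torch.min(surr1, surr2) + WEIGHT_LOSS * value_loss - WEIGHT_ENTROPY * dist_entropy
print(loss.mean())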
@@ -252,12 +195,11 @@ class PPOAgent(Policy):
         if train:
             self.train_net()

-    # Checkpointing methods
     def save(self, filename):
-        self.global_network.save(filename)
-        self.value_network.save(filename)
-        self.policy_network.save(filename)
-        torch.save(self.value_optimizer.state_dict(), filename + ".value_optimizer")
-        torch.save(self.policy_optimizer.state_dict(), filename + ".policy_optimizer")
+        # print("Saving model from checkpoint:", filename)
+        self.actor_critic_model.save(filename)
+        torch.save(self.optimizer.state_dict(), filename + ".optimizer")

     def _load(self, obj, filename):
         if os.path.exists(filename):
@@ -269,16 +211,13 @@ class PPOAgent(Policy):
         return obj

     def load(self, filename):
-        self.global_network.load(filename)
-        self.value_network.load(filename)
-        self.policy_network.load(filename)
-        self.value_optimizer = self._load(self.value_optimizer, filename + ".value_optimizer")
-        self.policy_optimizer = self._load(self.policy_optimizer, filename + ".policy_optimizer")
+        print("load policy from file", filename)
+        self.actor_critic_model.load(filename)
+        print("load optimizer from file", filename)
+        self.optimizer = self._load(self.optimizer, filename + ".optimizer")

     def clone(self):
         policy = PPOAgent(self.state_size, self.action_size)
-        policy.value_network = copy.deepcopy(self.value_network)
-        policy.policy_network = copy.deepcopy(self.policy_network)
-        policy.value_optimizer = copy.deepcopy(self.value_optimizer)
-        policy.policy_optimizer = copy.deepcopy(self.policy_optimizer)
+        policy.actor_critic_model = copy.deepcopy(self.actor_critic_model)
+        policy.optimizer = copy.deepcopy(self.optimizer)
         return self
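With a single model and a single optimizer, checkpointing now writes per-prefix files with the suffixes `.actor`, `.value` and `.optimizer`, while `load` looks for `.actor`, `.critic` and `.optimizer`, so the critic weights written as `.value` by this version do not appear to be restored. A hedged usage sketch, assuming the arbitrary sizes used above and an existing `./checkpoints` directory:

# assumes the PPOAgent defined above; sizes are arbitrary example values
agent = PPOAgent(state_size=231, action_size=5)

# writes checkpoints/ppo.actor, checkpoints/ppo.value and checkpoints/ppo.optimizer
agent.save("./checkpoints/ppo")

# a fresh agent restores whatever it finds under the same prefix
restored = PPOAgent(state_size=231, action_size=5)
restored.load("./checkpoints/ppo")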