Commit a854eed7 authored by Egli Adrian (IT-SCI-API-PFI)

looks good simplified

parent 8cf48167
-import torch.nn as nn
-class Policy(nn.Module):
-    def __init__(self):
-        super(Policy, self).__init__()
+class Policy:
     def step(self, handle, state, action, reward, next_state, done):
         raise NotImplementedError
@@ -40,39 +40,74 @@ class DataBuffers:
         self.memory.update({handle: transitions})
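Only the single DataBuffers line above is touched by this commit, so the rest of the class is collapsed by the diff viewer. As a rough, hypothetical stand-in (not the repository code) for the per-handle buffer interface used further down, something like the following is compatible with the push_transition call and the dict update shown above:

class MinimalDataBuffers:
    """Hypothetical sketch: one list of transitions per agent handle."""

    def __init__(self):
        self.memory = {}

    def push_transition(self, handle, transition):
        # append the transition to the handle's list, creating it on first use
        transitions = self.memory.get(handle, [])
        transitions.append(transition)
        self.memory.update({handle: transitions})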
+class PPOModelNetwork(nn.Module):
+    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128):
+        super(PPOModelNetwork, self).__init__()
+        self.fc_layer_1_val = nn.Linear(state_size, hidsize1)
+        self.shared_network = nn.Linear(hidsize1, hidsize2)
+        self.fc_policy_pi = nn.Linear(hidsize2, action_size)
+        self.fc_value = nn.Linear(hidsize2, 1)
+
+    def forward(self, x):
+        val = F.relu(self.fc_layer_1_val(x))
+        val = F.relu(self.shared_network(val))
+        return val
+
+    def policy_pi_estimator(self, x, softmax_dim=0):
+        x = F.tanh(self.forward(x))
+        x = self.fc_policy_pi(x)
+        prob = F.softmax(x, dim=softmax_dim)
+        return prob
+
+    def value_estimator(self, x):
+        x = F.tanh(self.forward(x))
+        v = self.fc_value(x)
+        return v
+
+    # Checkpointing methods
+    def save(self, filename):
+        # print("Saving model from checkpoint:", filename)
+        torch.save(self.shared_network.state_dict(), filename + ".fc_shared")
+        torch.save(self.fc_policy_pi.state_dict(), filename + ".fc_pi")
+        torch.save(self.fc_value.state_dict(), filename + ".fc_v")
+
+    def _load(self, obj, filename):
+        if os.path.exists(filename):
+            print(' >> ', filename)
+            try:
+                obj.load_state_dict(torch.load(filename, map_location=device))
+            except:
+                print(" >> failed!")
+        return obj
+
+    def load(self, filename):
+        print("load policy from file", filename)
+        self.shared_network = self._load(self.shared_network, filename + ".fc_shared")
+        self.fc_policy_pi = self._load(self.fc_policy_pi, filename + ".fc_pi")
+        self.fc_value = self._load(self.fc_value, filename + ".fc_v")
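The new PPOModelNetwork keeps a shared two-layer trunk (forward) and puts a softmax policy head and a scalar value head on top of it. A minimal smoke test of that interface, with placeholder sizes; this snippet is illustrative and not part of the commit:

import torch

state_size, action_size = 23, 5                       # placeholder sizes
net = PPOModelNetwork(state_size, action_size)

state = torch.randn(state_size)                       # a single observation vector
probs = net.policy_pi_estimator(state)                # action probabilities, shape [action_size]
value = net.value_estimator(state)                    # state-value estimate, shape [1]
assert torch.isclose(probs.sum(), torch.tensor(1.0))  # softmax output sums to one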
 class PPOAgent(Policy):
     def __init__(self, state_size, action_size):
         super(PPOAgent, self).__init__()
         self.memory = DataBuffers()
         self.loss = 0
-        self.fc1 = nn.Linear(state_size, 256)
-        self.fc_pi = nn.Linear(256, action_size)
-        self.fc_v = nn.Linear(256, 1)
-        self.optimizer = optim.Adam(self.parameters(), lr=LEARNING_RATE)
+        self.value_model_network = PPOModelNetwork(state_size, action_size)
+        self.optimizer = optim.Adam(self.value_model_network.parameters(), lr=LEARNING_RATE)

     def reset(self):
         pass

-    def pi(self, x, softmax_dim=0):
-        x = F.tanh(self.fc1(x))
-        x = self.fc_pi(x)
-        prob = F.softmax(x, dim=softmax_dim)
-        return prob
-
-    def v(self, x):
-        x = F.tanh(self.fc1(x))
-        v = self.fc_v(x)
-        return v
-
     def act(self, state, eps=None):
-        prob = self.pi(torch.from_numpy(state).float())
+        prob = self.value_model_network.policy_pi_estimator(torch.from_numpy(state).float())
         m = Categorical(prob)
         a = m.sample().item()
         return a

     def step(self, handle, state, action, reward, next_state, done):
         # Record the results of the agent's action as transition
-        prob = self.pi(torch.from_numpy(state).float())
+        prob = self.value_model_network.policy_pi_estimator(torch.from_numpy(state).float())
         transition = (state, action, reward, next_state, prob[action].item(), done)
         self.memory.push_transition(handle, transition)
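With the network factored out, act() samples an action from the current policy and step() records the transition, including the probability of the sampled action, in the per-handle buffer. A small illustrative driver with random observations; the sizes, the single handle 0, and the dummy rewards are placeholders, not the project's training loop:

import numpy as np

state_size, action_size = 23, 5                       # placeholder sizes
policy = PPOAgent(state_size, action_size)

state = np.random.rand(state_size).astype(np.float32)
for t in range(10):
    action = policy.act(state)                        # sample a ~ pi(.|state)
    next_state = np.random.rand(state_size).astype(np.float32)
    reward, done = 0.0, t == 9                        # dummy reward / termination flag
    policy.step(0, state, action, reward, next_state, done)
    state = next_state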
@@ -114,8 +149,10 @@ class PPOAgent(Policy):
         # run K_EPOCH optimisation steps
         for i in range(K_EPOCH):
             # temporal difference function / and prepare advantage function data
-            estimated_target_value = rewards + GAMMA * self.v(states_next) * (1.0 - dones)
-            difference_to_expected_value_deltas = estimated_target_value - self.v(states)
+            estimated_target_value = rewards + GAMMA * self.value_model_network.value_estimator(states_next) * (
+                    1.0 - dones)
+            difference_to_expected_value_deltas = estimated_target_value - self.value_model_network.value_estimator(
+                states)
             difference_to_expected_value_deltas = difference_to_expected_value_deltas.detach().numpy()

             # build advantage function and convert it to torch tensor (array)
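The lines that actually fill advantage_list are collapsed by the viewer and are not reproduced here. As a generic illustration only (not the repository code), the usual pattern is a reverse-order discounted accumulation of the TD deltas computed above; the GAMMA value and the smoothing constant LAMBDA below are assumptions:

import numpy as np

GAMMA, LAMBDA = 0.99, 0.95                  # assumed constants
deltas = np.array([[0.5], [-0.2], [0.1]])   # stand-in TD deltas, shape [T, 1]

advantage_list, advantage = [], 0.0
for delta in deltas[::-1]:                  # accumulate backwards in time
    advantage = GAMMA * LAMBDA * advantage + delta[0]
    advantage_list.append([advantage])
advantage_list.reverse()                    # restore forward time order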
@@ -128,7 +165,7 @@ class PPOAgent(Policy):
             advantages = torch.tensor(advantage_list, dtype=torch.float)

             # estimate pi_action for all state
-            pi_actions = self.pi(states, softmax_dim=1).gather(1, actions)
+            pi_actions = self.value_model_network.policy_pi_estimator(states, softmax_dim=1).gather(1, actions)

             # calculate the ratios
             ratios = torch.exp(torch.log(pi_actions) - torch.log(probs_action))

             # Normal Policy Gradient objective
@@ -136,7 +173,8 @@ class PPOAgent(Policy):
             # clipped version of Normal Policy Gradient objective
             clipped_surrogate_objective = torch.clamp(ratios * advantages, 1 - EPS_CLIP, 1 + EPS_CLIP)

             # value function loss
-            value_loss = F.mse_loss(self.v(states), estimated_target_value.detach())
+            value_loss = F.mse_loss(self.value_model_network.value_estimator(states),
+                                    estimated_target_value.detach())

             # loss
             loss = -torch.min(surrogate_objective, clipped_surrogate_objective) + value_loss
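For intuition, the clipping keeps the importance ratio pi_new/pi_old from driving a single update too far. A tiny self-contained numeric illustration in the standard PPO form, where the ratio is clamped before being multiplied by the advantage; EPS_CLIP = 0.2 is an assumed value, not taken from the elided constants:

import torch

EPS_CLIP = 0.2                                      # assumed clipping range
ratios = torch.tensor([0.7, 1.0, 1.5])              # pi_new / pi_old per sample
advantages = torch.tensor([1.0, 1.0, 1.0])

surrogate = ratios * advantages
clipped = torch.clamp(ratios, 1 - EPS_CLIP, 1 + EPS_CLIP) * advantages
objective = torch.min(surrogate, clipped)           # tensor([0.7000, 1.0000, 1.2000])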
@@ -156,9 +194,7 @@ class PPOAgent(Policy):
     # Checkpointing methods
     def save(self, filename):
         # print("Saving model from checkpoint:", filename)
-        torch.save(self.fc1.state_dict(), filename + ".fc1")
-        torch.save(self.fc_pi.state_dict(), filename + ".fc_pi")
-        torch.save(self.fc_v.state_dict(), filename + ".fc_v")
+        self.value_model_network.save(filename)
         torch.save(self.optimizer.state_dict(), filename + ".optimizer")

     def _load(self, obj, filename):
@@ -172,9 +208,8 @@ class PPOAgent(Policy):
     def load(self, filename):
         print("load policy from file", filename)
-        self.fc1 = self._load(self.fc1, filename + ".fc1")
-        self.fc_pi = self._load(self.fc_pi, filename + ".fc_pi")
-        self.fc_v = self._load(self.fc_v, filename + ".fc_v")
+        self.value_model_network.load(filename)
         print("load optimizer from file", filename)
         self.optimizer = self._load(self.optimizer, filename + ".optimizer")

     def clone(self):
......
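After the refactoring, PPOAgent.save() delegates the network weights to PPOModelNetwork.save() and additionally stores the optimizer state, while load() restores each saved part if the corresponding file exists. An illustrative round trip; the file prefix and sizes are placeholders:

policy = PPOAgent(state_size=23, action_size=5)
policy.save("ppo_checkpoint")      # writes ppo_checkpoint.fc_shared / .fc_pi / .fc_v / .optimizer

restored = PPOAgent(state_size=23, action_size=5)
restored.load("ppo_checkpoint")    # reloads each saved part, skipping files that do not exist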