Commit 87273288 authored by Egli Adrian (IT-SCI-API-PFI)

small fix in object

parent f1cb653e
@@ -40,37 +40,85 @@ class DataBuffers:
         self.memory.update({handle: transitions})


-class PPOModelNetwork(nn.Module):
+class GlobalModel(nn.Module):
     def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128):
-        super(PPOModelNetwork, self).__init__()
-        self.fc_layer_1_val = nn.Linear(state_size, hidsize1)
-        self.shared_network = nn.Linear(hidsize1, hidsize2)
-        self.fc_policy_pi = nn.Linear(hidsize2, action_size)
-        self.fc_value = nn.Linear(hidsize2, 1)
+        super(GlobalModel, self).__init__()
+        self._layer_1 = nn.Linear(state_size, hidsize1)
+        self.global_network = nn.Linear(hidsize1, hidsize2)
+
+    def get_model(self):
+        return self.global_network

     def forward(self, x):
-        val = F.relu(self.fc_layer_1_val(x))
-        val = F.relu(self.shared_network(val))
+        val = F.relu(self._layer_1(x))
+        val = F.relu(self.global_network(val))
         return val

-    def policy_pi_estimator(self, x, softmax_dim=0):
-        x = F.tanh(self.forward(x))
-        x = self.fc_policy_pi(x)
+    def save(self, filename):
+        # print("Saving model from checkpoint:", filename)
+        torch.save(self.global_network.state_dict(), filename + ".global")
+
+    def _load(self, obj, filename):
+        if os.path.exists(filename):
+            print(' >> ', filename)
+            try:
+                obj.load_state_dict(torch.load(filename, map_location=device))
+            except:
+                print(" >> failed!")
+        return obj
+
+    def load(self, filename):
+        self.global_network = self._load(self.global_network, filename + ".global")
+
+
+class PolicyNetwork(nn.Module):
+    def __init__(self, state_size, action_size, global_network, hidsize1=128, hidsize2=128):
+        super(PolicyNetwork, self).__init__()
+        self.global_network = global_network
+        self.policy_network = nn.Linear(hidsize2, action_size)
+
+    def forward(self, x, softmax_dim=0):
+        x = F.tanh(self.global_network.forward(x))
+        x = self.policy_network(x)
         prob = F.softmax(x, dim=softmax_dim)
         return prob

-    def value_estimator(self, x):
-        x = F.tanh(self.forward(x))
-        v = self.fc_value(x)
+    # Checkpointing methods
+    def save(self, filename):
+        # print("Saving model from checkpoint:", filename)
+        torch.save(self.policy_network.state_dict(), filename + ".policy")
+
+    def _load(self, obj, filename):
+        if os.path.exists(filename):
+            print(' >> ', filename)
+            try:
+                obj.load_state_dict(torch.load(filename, map_location=device))
+            except:
+                print(" >> failed!")
+        return obj
+
+    def load(self, filename):
+        print("load policy from file", filename)
+        self.policy_network = self._load(self.policy_network, filename + ".policy")
+
+
+class ValueNetwork(nn.Module):
+    def __init__(self, state_size, action_size, global_network, hidsize1=128, hidsize2=128):
+        super(ValueNetwork, self).__init__()
+        self.global_network = global_network
+        self.value_network = nn.Linear(hidsize2, 1)
+
+    def forward(self, x):
+        x = F.tanh(self.global_network.forward(x))
+        v = self.value_network(x)
         return v

     # Checkpointing methods
     def save(self, filename):
         # print("Saving model from checkpoint:", filename)
-        torch.save(self.shared_network.state_dict(), filename + ".fc_shared")
-        torch.save(self.fc_policy_pi.state_dict(), filename + ".fc_pi")
-        torch.save(self.fc_value.state_dict(), filename + ".fc_v")
+        torch.save(self.value_network.state_dict(), filename + ".value")

     def _load(self, obj, filename):
         if os.path.exists(filename):
@@ -82,32 +130,40 @@ class PPOModelNetwork(nn.Module):
         return obj

     def load(self, filename):
         print("load policy from file", filename)
-        self.shared_network = self._load(self.shared_network, filename + ".fc_shared")
-        self.fc_policy_pi = self._load(self.fc_policy_pi, filename + ".fc_pi")
-        self.fc_value = self._load(self.fc_value, filename + ".fc_v")
+        self.value_network = self._load(self.value_network, filename + ".value")


 class PPOAgent(Policy):
     def __init__(self, state_size, action_size):
         super(PPOAgent, self).__init__()
+        # create the data buffer - collects all transitions (state, action, reward, next_state, action_prob, done)
+        # each agent owns its own buffer
         self.memory = DataBuffers()
+        # signal - stores the current loss
         self.loss = 0
-        self.value_model_network = PPOModelNetwork(state_size, action_size)
-        self.optimizer = optim.Adam(self.value_model_network.parameters(), lr=LEARNING_RATE)
+        # create the global, shared deep neural network
+        self.global_network = GlobalModel(state_size, action_size)
+        # create the "critic" or value network
+        self.value_network = ValueNetwork(state_size, action_size, self.global_network)
+        # create the "actor" or policy network
+        self.policy_network = PolicyNetwork(state_size, action_size, self.global_network)
+        # create an optimizer for each network
+        self.value_optimizer = optim.Adam(self.value_network.parameters(), lr=LEARNING_RATE)
+        self.policy_optimizer = optim.Adam(self.policy_network.parameters(), lr=LEARNING_RATE)

     def reset(self):
         pass

     def act(self, state, eps=None):
-        prob = self.value_model_network.policy_pi_estimator(torch.from_numpy(state).float())
+        prob = self.policy_network(torch.from_numpy(state).float())
         m = Categorical(prob)
         a = m.sample().item()
         return a

     def step(self, handle, state, action, reward, next_state, done):
         # Record the results of the agent's action as a transition
-        prob = self.value_model_network.policy_pi_estimator(torch.from_numpy(state).float())
+        prob = self.policy_network(torch.from_numpy(state).float())
         transition = (state, action, reward, next_state, prob[action].item(), done)
         self.memory.push_transition(handle, transition)
@@ -149,10 +205,10 @@ class PPOAgent(Policy):
         # run K_EPOCH optimisation steps
         for i in range(K_EPOCH):
             # temporal difference function / prepare advantage function data
-            estimated_target_value = rewards + GAMMA * self.value_model_network.value_estimator(states_next) * (
-                    1.0 - dones)
-            difference_to_expected_value_deltas = estimated_target_value - self.value_model_network.value_estimator(
-                states)
+            estimated_target_value = \
+                rewards + GAMMA * self.value_network(states_next) * (1.0 - dones)
+            difference_to_expected_value_deltas = \
+                estimated_target_value - self.value_network(states)
             difference_to_expected_value_deltas = difference_to_expected_value_deltas.detach().numpy()

             # build advantage function and convert it to torch tensor (array)
@@ -165,37 +221,41 @@ class PPOAgent(Policy):
             advantages = torch.tensor(advantage_list, dtype=torch.float)

             # estimate pi_action for all states
-            pi_actions = self.value_model_network.policy_pi_estimator(states, softmax_dim=1).gather(1, actions)
+            pi_actions = self.policy_network.forward(states, softmax_dim=1).gather(1, actions)

             # calculate the ratios
             ratios = torch.exp(torch.log(pi_actions) - torch.log(probs_action))

             # Normal Policy Gradient objective
             surrogate_objective = ratios * advantages

             # clipped version of Normal Policy Gradient objective
             clipped_surrogate_objective = torch.clamp(ratios * advantages, 1 - EPS_CLIP, 1 + EPS_CLIP)

-            # value function loss
-            value_loss = F.mse_loss(self.value_model_network.value_estimator(states),
+            # create value loss function
+            value_loss = F.mse_loss(self.value_network(states),
                                     estimated_target_value.detach())

-            # loss
+            # create final loss function
             loss = -torch.min(surrogate_objective, clipped_surrogate_objective) + value_loss

-            # update policy and actor networks
-            self.optimizer.zero_grad()
+            # update policy ("actor") and value ("critic") networks
+            self.value_optimizer.zero_grad()
+            self.policy_optimizer.zero_grad()
             loss.mean().backward()
-            self.optimizer.step()
+            self.value_optimizer.step()
+            self.policy_optimizer.step()

-        # store current loss to the agent
+        # store current loss
         self.loss = loss.mean().detach().numpy()

         self.memory.reset()

     def end_episode(self, train):
         if train:
             self.train_net()

     # Checkpointing methods
     def save(self, filename):
         # print("Saving model from checkpoint:", filename)
-        self.value_model_network.save(filename)
-        torch.save(self.optimizer.state_dict(), filename + ".optimizer")
+        self.global_network.save(filename)
+        self.value_network.save(filename)
+        self.policy_network.save(filename)
+        torch.save(self.value_optimizer.state_dict(), filename + ".value_optimizer")
+        torch.save(self.policy_optimizer.state_dict(), filename + ".policy_optimizer")

     def _load(self, obj, filename):
         if os.path.exists(filename):
@@ -207,15 +267,16 @@ class PPOAgent(Policy):
         return obj

     def load(self, filename):
         print("load policy from file", filename)
-        self.value_model_network.load(filename)
-        print("load optimizer from file", filename)
-        self.optimizer = self._load(self.optimizer, filename + ".optimizer")
+        self.global_network.load(filename)
+        self.value_network.load(filename)
+        self.policy_network.load(filename)
+        self.value_optimizer = self._load(self.value_optimizer, filename + ".value_optimizer")
+        self.policy_optimizer = self._load(self.policy_optimizer, filename + ".policy_optimizer")

     def clone(self):
         policy = PPOAgent(self.state_size, self.action_size)
-        policy.fc1 = copy.deepcopy(self.fc1)
-        policy.fc_pi = copy.deepcopy(self.fc_pi)
-        policy.fc_v = copy.deepcopy(self.fc_v)
-        policy.optimizer = copy.deepcopy(self.optimizer)
+        policy.value_network = copy.deepcopy(self.value_network)
+        policy.policy_network = copy.deepcopy(self.policy_network)
+        policy.value_optimizer = copy.deepcopy(self.value_optimizer)
+        policy.policy_optimizer = copy.deepcopy(self.policy_optimizer)
         return self
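
For orientation, a minimal sketch of how the refactored classes in this commit fit together: the shared GlobalModel trunk feeds both the PolicyNetwork ("actor") head and the ValueNetwork ("critic") head, mirroring what PPOAgent.__init__ now does. The module name ppo_agent and the sizes below are illustrative assumptions, not part of the commit.

import torch
from torch.distributions import Categorical
from ppo_agent import GlobalModel, PolicyNetwork, ValueNetwork  # hypothetical module name

state_size, action_size = 23, 5                          # illustrative sizes only

shared = GlobalModel(state_size, action_size)            # shared trunk
actor = PolicyNetwork(state_size, action_size, shared)   # policy ("actor") head
critic = ValueNetwork(state_size, action_size, shared)   # value ("critic") head

state = torch.randn(state_size)                          # fake observation vector
prob = actor(state)                                      # action probabilities (softmax over actions)
value = critic(state)                                    # scalar state-value estimate
action = Categorical(prob).sample().item()               # sample an action, as PPOAgent.act does

Since PolicyNetwork and ValueNetwork each register the shared GlobalModel as a submodule, its parameters appear in both of PPOAgent's optimizers (value_optimizer and policy_optimizer), so the shared trunk is trained through both heads.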