From 872732881018489ca0da6b92dc541730efa972ee Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Thu, 3 Dec 2020 14:36:39 +0100
Subject: [PATCH] split PPO model into shared trunk with separate policy and value networks

---
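Patch notes: this change splits the former PPOModelNetwork into three modules,
GlobalModel (the shared feature trunk), PolicyNetwork (the "actor" head) and
ValueNetwork (the "critic" head), each with its own checkpoint files and its own
Adam optimizer. The sketch below illustrates how the pieces are wired together
and how one PPO update looks with the split networks. It is a minimal
illustration only, not an excerpt of the patched file: the batch data is random,
the hyper-parameter values are placeholders for the constants defined near the
top of ppo_agent.py, and the clipped surrogate is written in its textbook form
(clamp the ratio, then weight by the advantage).

    import torch
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.distributions import Categorical

    # assuming the repository root is on PYTHONPATH
    from reinforcement_learning.ppo.ppo_agent import GlobalModel, PolicyNetwork, ValueNetwork

    # placeholder hyper-parameters; the real constants live in ppo_agent.py
    LEARNING_RATE, GAMMA, EPS_CLIP = 1e-4, 0.95, 0.1
    state_size, action_size = 23, 5                    # made-up sizes

    global_network = GlobalModel(state_size, action_size)                     # shared trunk
    policy_network = PolicyNetwork(state_size, action_size, global_network)   # actor head
    value_network = ValueNetwork(state_size, action_size, global_network)     # critic head
    policy_optimizer = optim.Adam(policy_network.parameters(), lr=LEARNING_RATE)
    value_optimizer = optim.Adam(value_network.parameters(), lr=LEARNING_RATE)

    # acting: sample an action from the policy head
    state = torch.rand(state_size)
    action = Categorical(policy_network(state)).sample().item()

    # one optimisation step on a toy batch of transitions
    states, states_next = torch.rand(4, state_size), torch.rand(4, state_size)
    actions = torch.randint(0, action_size, (4, 1))
    rewards, dones = torch.rand(4, 1), torch.zeros(4, 1)
    probs_action = policy_network(states, softmax_dim=1).gather(1, actions).detach()

    # temporal-difference target and a one-step advantage estimate
    # (the patched code accumulates these deltas into an advantage list over the episode)
    target = rewards + GAMMA * value_network(states_next) * (1.0 - dones)
    advantages = (target - value_network(states)).detach()

    pi_actions = policy_network(states, softmax_dim=1).gather(1, actions)
    ratios = torch.exp(torch.log(pi_actions) - torch.log(probs_action))
    surrogate = ratios * advantages
    clipped_surrogate = torch.clamp(ratios, 1 - EPS_CLIP, 1 + EPS_CLIP) * advantages
    value_loss = F.mse_loss(value_network(states), target.detach())
    loss = -torch.min(surrogate, clipped_surrogate) + value_loss

    value_optimizer.zero_grad()
    policy_optimizer.zero_grad()
    loss.mean().backward()
    value_optimizer.step()
    policy_optimizer.step()

Because both heads keep the GlobalModel instance as a registered submodule,
policy_network.parameters() and value_network.parameters() each include the
shared trunk, so the trunk receives gradient updates from the value step as
well as the policy step.
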
 reinforcement_learning/ppo/ppo_agent.py | 161 ++++++++++++++++--------
 1 file changed, 111 insertions(+), 50 deletions(-)

diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py
index f213464..3322c82 100644
--- a/reinforcement_learning/ppo/ppo_agent.py
+++ b/reinforcement_learning/ppo/ppo_agent.py
@@ -40,37 +40,85 @@ class DataBuffers:
         self.memory.update({handle: transitions})
 
 
-class PPOModelNetwork(nn.Module):
-
+class GlobalModel(nn.Module):
     def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128):
-        super(PPOModelNetwork, self).__init__()
-        self.fc_layer_1_val = nn.Linear(state_size, hidsize1)
-        self.shared_network = nn.Linear(hidsize1, hidsize2)
-        self.fc_policy_pi = nn.Linear(hidsize2, action_size)
-        self.fc_value = nn.Linear(hidsize2, 1)
+        super(GlobalModel, self).__init__()
+        self._layer_1 = nn.Linear(state_size, hidsize1)
+        self.global_network = nn.Linear(hidsize1, hidsize2)
+
+    def get_model(self):
+        return self.global_network
 
     def forward(self, x):
-        val = F.relu(self.fc_layer_1_val(x))
-        val = F.relu(self.shared_network(val))
+        val = F.relu(self._layer_1(x))
+        val = F.relu(self.global_network(val))
         return val
 
-    def policy_pi_estimator(self, x, softmax_dim=0):
-        x = F.tanh(self.forward(x))
-        x = self.fc_policy_pi(x)
+    def save(self, filename):
+        # print("Saving model to checkpoint:", filename)
+        torch.save(self.global_network.state_dict(), filename + ".global")
+
+    def _load(self, obj, filename):
+        if os.path.exists(filename):
+            print(' >> ', filename)
+            try:
+                obj.load_state_dict(torch.load(filename, map_location=device))
+            except Exception:
+                print(" >> failed!")
+        return obj
+
+    def load(self, filename):
+        self.global_network = self._load(self.global_network, filename + ".global")
+
+
+class PolicyNetwork(nn.Module):
+
+    def __init__(self, state_size, action_size, global_network, hidsize1=128, hidsize2=128):
+        super(PolicyNetwork, self).__init__()
+        self.global_network = global_network
+        self.policy_network = nn.Linear(hidsize2, action_size)
+
+    def forward(self, x, softmax_dim=0):
+        x = torch.tanh(self.global_network(x))
+        x = self.policy_network(x)
         prob = F.softmax(x, dim=softmax_dim)
         return prob
 
-    def value_estimator(self, x):
-        x = F.tanh(self.forward(x))
-        v = self.fc_value(x)
+    # Checkpointing methods
+    def save(self, filename):
+        # print("Saving model to checkpoint:", filename)
+        torch.save(self.policy_network.state_dict(), filename + ".policy")
+
+    def _load(self, obj, filename):
+        if os.path.exists(filename):
+            print(' >> ', filename)
+            try:
+                obj.load_state_dict(torch.load(filename, map_location=device))
+            except Exception:
+                print(" >> failed!")
+        return obj
+
+    def load(self, filename):
+        print("load policy from file", filename)
+        self.policy_network = self._load(self.policy_network, filename + ".policy")
+
+
+class ValueNetwork(nn.Module):
+
+    def __init__(self, state_size, action_size, global_network, hidsize1=128, hidsize2=128):
+        super(ValueNetwork, self).__init__()
+        self.global_network = global_network
+        self.value_network = nn.Linear(hidsize2, 1)
+
+    def forward(self, x):
+        x = torch.tanh(self.global_network(x))
+        v = self.value_network(x)
         return v
 
     # Checkpointing methods
     def save(self, filename):
         # print("Saving model from checkpoint:", filename)
-        torch.save(self.shared_network.state_dict(), filename + ".fc_shared")
-        torch.save(self.fc_policy_pi.state_dict(), filename + ".fc_pi")
-        torch.save(self.fc_value.state_dict(), filename + ".fc_v")
+        torch.save(self.value_network.state_dict(), filename + ".value")
 
     def _load(self, obj, filename):
         if os.path.exists(filename):
@@ -82,32 +130,40 @@ class PPOModelNetwork(nn.Module):
         return obj
 
     def load(self, filename):
-        print("load policy from file", filename)
-        self.shared_network = self._load(self.shared_network, filename + ".fc_shared")
-        self.fc_policy_pi = self._load(self.fc_policy_pi, filename + ".fc_pi")
-        self.fc_value = self._load(self.fc_value, filename + ".fc_v")
+        self.value_network = self._load(self.value_network, filename + ".value")
 
 
 class PPOAgent(Policy):
     def __init__(self, state_size, action_size):
         super(PPOAgent, self).__init__()
+        # create the data buffer - collects all transitions (state, action, reward, next_state, action_prob, done)
+        # transitions are stored per agent handle, so each agent has its own list
         self.memory = DataBuffers()
+        # stores the most recent training loss (for monitoring)
         self.loss = 0
-        self.value_model_network = PPOModelNetwork(state_size, action_size)
-        self.optimizer = optim.Adam(self.value_model_network.parameters(), lr=LEARNING_RATE)
+        # create the shared ("global") deep neural network trunk
+        self.global_network = GlobalModel(state_size, action_size)
+        # create the "critic" or value network
+        self.value_network = ValueNetwork(state_size, action_size, self.global_network)
+        # create the "actor" or policy network
+        self.policy_network = PolicyNetwork(state_size, action_size, self.global_network)
+        # create a separate optimizer for each network
+        self.value_optimizer = optim.Adam(self.value_network.parameters(), lr=LEARNING_RATE)
+        self.policy_optimizer = optim.Adam(self.policy_network.parameters(), lr=LEARNING_RATE)
+
 
     def reset(self):
         pass
 
     def act(self, state, eps=None):
-        prob = self.value_model_network.policy_pi_estimator(torch.from_numpy(state).float())
+        prob = self.policy_network(torch.from_numpy(state).float())
         m = Categorical(prob)
         a = m.sample().item()
         return a
 
     def step(self, handle, state, action, reward, next_state, done):
         # Record the results of the agent's action as transition
-        prob = self.value_model_network.policy_pi_estimator(torch.from_numpy(state).float())
+        prob = self.policy_network(torch.from_numpy(state).float())
         transition = (state, action, reward, next_state, prob[action].item(), done)
         self.memory.push_transition(handle, transition)
 
@@ -149,10 +205,10 @@ class PPOAgent(Policy):
                 # run K_EPOCH optimisation steps
                 for i in range(K_EPOCH):
                     # temporal difference function / and prepare advantage function data
-                    estimated_target_value = rewards + GAMMA * self.value_model_network.value_estimator(states_next) * (
-                            1.0 - dones)
-                    difference_to_expected_value_deltas = estimated_target_value - self.value_model_network.value_estimator(
-                        states)
+                    estimated_target_value = \
+                        rewards + GAMMA * self.value_network(states_next) * (1.0 - dones)
+                    difference_to_expected_value_deltas = \
+                        estimated_target_value - self.value_network(states)
                     difference_to_expected_value_deltas = difference_to_expected_value_deltas.detach().numpy()
 
                     # build advantage function and convert it to torch tensor (array)
@@ -165,37 +221,41 @@ class PPOAgent(Policy):
                     advantages = torch.tensor(advantage_list, dtype=torch.float)
 
                     # estimate pi_action for all state
-                    pi_actions = self.value_model_network.policy_pi_estimator(states, softmax_dim=1).gather(1, actions)
+                    pi_actions = self.policy_network(states, softmax_dim=1).gather(1, actions)
                     # calculate the ratios
                     ratios = torch.exp(torch.log(pi_actions) - torch.log(probs_action))
                     # Normal Policy Gradient objective
                     surrogate_objective = ratios * advantages
                     # clipped version of Normal Policy Gradient objective
                     clipped_surrogate_objective = torch.clamp(ratios * advantages, 1 - EPS_CLIP, 1 + EPS_CLIP)
-                    # value function loss
-                    value_loss = F.mse_loss(self.value_model_network.value_estimator(states),
+                    # value function ("critic") loss
+                    value_loss = F.mse_loss(self.value_network(states),
                                             estimated_target_value.detach())
-                    # loss
+                    # total loss: clipped policy objective plus value loss
                     loss = -torch.min(surrogate_objective, clipped_surrogate_objective) + value_loss
 
-                    # update policy and actor networks
-                    self.optimizer.zero_grad()
+                    # update policy ("actor") and value ("critic") networks
+                    self.value_optimizer.zero_grad()
+                    self.policy_optimizer.zero_grad()
                     loss.mean().backward()
-                    self.optimizer.step()
+                    self.value_optimizer.step()
+                    self.policy_optimizer.step()
 
-                    # store current loss to the agent
+                    # store current loss
                     self.loss = loss.mean().detach().numpy()
+
         self.memory.reset()
 
     def end_episode(self, train):
         if train:
             self.train_net()
 
-    # Checkpointing methods
     def save(self, filename):
-        # print("Saving model from checkpoint:", filename)
-        self.value_model_network.save(filename)
-        torch.save(self.optimizer.state_dict(), filename + ".optimizer")
+        self.global_network.save(filename)
+        self.value_network.save(filename)
+        self.policy_network.save(filename)
+        torch.save(self.value_optimizer.state_dict(), filename + ".value_optimizer")
+        torch.save(self.policy_optimizer.state_dict(), filename + ".policy_optimizer")
 
     def _load(self, obj, filename):
         if os.path.exists(filename):
@@ -207,15 +267,16 @@ class PPOAgent(Policy):
         return obj
 
     def load(self, filename):
-        print("load policy from file", filename)
-        self.value_model_network.load(filename)
-        print("load optimizer from file", filename)
-        self.optimizer = self._load(self.optimizer, filename + ".optimizer")
+        self.global_network.load(filename)
+        self.value_network.load(filename)
+        self.policy_network.load(filename)
+        self.value_optimizer = self._load(self.value_optimizer, filename + ".value_optimizer")
+        self.policy_optimizer = self._load(self.policy_optimizer, filename + ".policy_optimizer")
 
     def clone(self):
         policy = PPOAgent(self.state_size, self.action_size)
-        policy.fc1 = copy.deepcopy(self.fc1)
-        policy.fc_pi = copy.deepcopy(self.fc_pi)
-        policy.fc_v = copy.deepcopy(self.fc_v)
-        policy.optimizer = copy.deepcopy(self.optimizer)
+        policy.value_network = copy.deepcopy(self.value_network)
+        policy.policy_network = copy.deepcopy(self.policy_network)
+        policy.value_optimizer = copy.deepcopy(self.value_optimizer)
+        policy.policy_optimizer = copy.deepcopy(self.policy_optimizer)
         return self
-- 
GitLab
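
Checkpoint layout note: PPOAgent.save(filename) now writes filename + ".global",
".policy", ".value", ".value_optimizer" and ".policy_optimizer" instead of the
previous ".fc_shared" / ".fc_pi" / ".fc_v" / ".optimizer" files, so checkpoints
written before this patch are simply skipped by load(). A minimal round-trip
sketch; the path prefix and the state/action sizes below are made up for
illustration:

    from reinforcement_learning.ppo.ppo_agent import PPOAgent  # assuming repo root on PYTHONPATH

    agent = PPOAgent(state_size=23, action_size=5)    # hypothetical sizes
    agent.save("ppo_checkpoint")                      # hypothetical path prefix
    # writes ppo_checkpoint.global, .policy, .value, .value_optimizer, .policy_optimizer
    restored = PPOAgent(state_size=23, action_size=5)
    restored.load("ppo_checkpoint")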