From a854eed769a66293641cfb926ebc9bb5cbfa955b Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Thu, 3 Dec 2020 13:03:35 +0100
Subject: [PATCH] Simplify Policy base class and extract the PPO network into PPOModelNetwork

---
 reinforcement_learning/policy.py        |  5 +-
 reinforcement_learning/ppo/ppo_agent.py | 97 ++++++++++++++++++++--------
 2 files changed, 71 insertions(+), 31 deletions(-)

diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
index 33726d6..45889da 100644
--- a/reinforcement_learning/policy.py
+++ b/reinforcement_learning/policy.py
@@ -1,9 +1,6 @@
 import torch.nn as nn
 
-class Policy(nn.Module):
-    def __init__(self):
-        super(Policy, self).__init__()
-
+class Policy:
     def step(self, handle, state, action, reward, next_state, done):
         raise NotImplementedError
 
diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py
index d1064f3..07b2079 100644
--- a/reinforcement_learning/ppo/ppo_agent.py
+++ b/reinforcement_learning/ppo/ppo_agent.py
@@ -40,39 +40,82 @@ class DataBuffers:
         self.memory.update({handle: transitions})
 
 
+class PPOModelNetwork(nn.Module):
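+    """Actor-critic network: two shared fully connected layers feed a softmax policy head and a scalar value head."""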
+
+    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128):
+        super(PPOModelNetwork, self).__init__()
+        self.fc_layer_1_val = nn.Linear(state_size, hidsize1)
+        self.shared_network = nn.Linear(hidsize1, hidsize2)
+        self.fc_policy_pi = nn.Linear(hidsize2, action_size)
+        self.fc_value = nn.Linear(hidsize2, 1)
+
+    def forward(self, x):
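+        # Shared trunk: two ReLU-activated fully connected layers used by both heads.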
+        val = F.relu(self.fc_layer_1_val(x))
+        val = F.relu(self.shared_network(val))
+        return val
+
+    def policy_pi_estimator(self, x, softmax_dim=0):
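+        # Policy head: tanh of the shared features, then a linear layer and a softmax over actions.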
+        x = F.tanh(self.forward(x))
+        x = self.fc_policy_pi(x)
+        prob = F.softmax(x, dim=softmax_dim)
+        return prob
+
+    def value_estimator(self, x):
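+        # Value head: tanh of the shared features, then a linear layer producing the state-value estimate.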
+        x = F.tanh(self.forward(x))
+        v = self.fc_value(x)
+        return v
+
+    # Checkpointing methods
+    def save(self, filename):
+        # print("Saving model from checkpoint:", filename)
+        torch.save(self.fc_layer_1_val.state_dict(), filename + ".fc_layer_1")
+        torch.save(self.shared_network.state_dict(), filename + ".fc_shared")
+        torch.save(self.fc_policy_pi.state_dict(), filename + ".fc_pi")
+        torch.save(self.fc_value.state_dict(), filename + ".fc_v")
+
+    def _load(self, obj, filename):
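+        # Load a state dict into obj if the checkpoint file exists; otherwise keep the current weights.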
+        if os.path.exists(filename):
+            print(' >> ', filename)
+            try:
+                obj.load_state_dict(torch.load(filename, map_location=device))
+            except Exception as exception:
+                print(" >> failed!", exception)
+        return obj
+
+    def load(self, filename):
+        print("load policy from file", filename)
+        self.shared_network = self._load(self.shared_network, filename + ".fc_shared")
+        self.fc_policy_pi = self._load(self.fc_policy_pi, filename + ".fc_pi")
+        self.fc_value = self._load(self.fc_value, filename + ".fc_v")
+
+
 class PPOAgent(Policy):
     def __init__(self, state_size, action_size):
         super(PPOAgent, self).__init__()
         self.memory = DataBuffers()
         self.loss = 0
-        self.fc1 = nn.Linear(state_size, 256)
-        self.fc_pi = nn.Linear(256, action_size)
-        self.fc_v = nn.Linear(256, 1)
-        self.optimizer = optim.Adam(self.parameters(), lr=LEARNING_RATE)
+        self.value_model_network = PPOModelNetwork(state_size, action_size)
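+        # Single Adam optimizer over all parameters of the network (shared trunk, policy head and value head).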
+        self.optimizer = optim.Adam(self.value_model_network.parameters(), lr=LEARNING_RATE)
 
     def reset(self):
         pass
 
-    def pi(self, x, softmax_dim=0):
-        x = F.tanh(self.fc1(x))
-        x = self.fc_pi(x)
-        prob = F.softmax(x, dim=softmax_dim)
-        return prob
-
-    def v(self, x):
-        x = F.tanh(self.fc1(x))
-        v = self.fc_v(x)
-        return v
-
     def act(self, state, eps=None):
-        prob = self.pi(torch.from_numpy(state).float())
+        prob = self.value_model_network.policy_pi_estimator(torch.from_numpy(state).float())
         m = Categorical(prob)
         a = m.sample().item()
         return a
 
     def step(self, handle, state, action, reward, next_state, done):
         # Record the results of the agent's action as transition
-        prob = self.pi(torch.from_numpy(state).float())
+        prob = self.value_model_network.policy_pi_estimator(torch.from_numpy(state).float())
         transition = (state, action, reward, next_state, prob[action].item(), done)
         self.memory.push_transition(handle, transition)
 
@@ -114,8 +149,10 @@ class PPOAgent(Policy):
                 # run K_EPOCH optimisation steps
                 for i in range(K_EPOCH):
                     # temporal difference function / and prepare advantage function data
-                    estimated_target_value = rewards + GAMMA * self.v(states_next) * (1.0 - dones)
-                    difference_to_expected_value_deltas = estimated_target_value - self.v(states)
+                    estimated_target_value = rewards + GAMMA * self.value_model_network.value_estimator(states_next) * (
+                            1.0 - dones)
+                    difference_to_expected_value_deltas = estimated_target_value - self.value_model_network.value_estimator(
+                        states)
                     difference_to_expected_value_deltas = difference_to_expected_value_deltas.detach().numpy()
 
                     # build advantage function and convert it to torch tensor (array)
@@ -128,7 +165,7 @@ class PPOAgent(Policy):
                     advantages = torch.tensor(advantage_list, dtype=torch.float)
 
                     # estimate pi_action for all state
-                    pi_actions = self.pi(states, softmax_dim=1).gather(1, actions)
+                    pi_actions = self.value_model_network.policy_pi_estimator(states, softmax_dim=1).gather(1, actions)
                     # calculate the ratios
                     ratios = torch.exp(torch.log(pi_actions) - torch.log(probs_action))
                     # Normal Policy Gradient objective
@@ -136,7 +173,8 @@ class PPOAgent(Policy):
                     # clipped version of Normal Policy Gradient objective
                     clipped_surrogate_objective = torch.clamp(ratios * advantages, 1 - EPS_CLIP, 1 + EPS_CLIP)
                     # value function loss
-                    value_loss = F.mse_loss(self.v(states), estimated_target_value.detach())
+                    value_loss = F.mse_loss(self.value_model_network.value_estimator(states),
+                                            estimated_target_value.detach())
                     # loss
                     loss = -torch.min(surrogate_objective, clipped_surrogate_objective) + value_loss
 
@@ -156,9 +194,7 @@ class PPOAgent(Policy):
     # Checkpointing methods
     def save(self, filename):
         # print("Saving model from checkpoint:", filename)
-        torch.save(self.fc1.state_dict(), filename + ".fc1")
-        torch.save(self.fc_pi.state_dict(), filename + ".fc_pi")
-        torch.save(self.fc_v.state_dict(), filename + ".fc_v")
+        self.value_model_network.save(filename)
         torch.save(self.optimizer.state_dict(), filename + ".optimizer")
 
     def _load(self, obj, filename):
@@ -172,9 +208,8 @@ class PPOAgent(Policy):
 
     def load(self, filename):
         print("load policy from file", filename)
-        self.fc1 = self._load(self.fc1, filename + ".fc1")
-        self.fc_pi = self._load(self.fc_pi, filename + ".fc_pi")
-        self.fc_v = self._load(self.fc_v, filename + ".fc_v")
+        self.value_model_network.load(filename)
+        print("load optimizer from file", filename)
         self.optimizer = self._load(self.optimizer, filename + ".optimizer")
 
     def clone(self):
-- 
GitLab