From 168d6728eb8f2e373d3557690d5308982370251e Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Sat, 12 Dec 2020 23:01:33 +0100
Subject: [PATCH] converges much faster :-)

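Train and evaluate with the PPO agent instead of DDDQN. The discounted
return in ppo_agent.py is now computed from a sparse terminal reward
(1 when an agent is done, 0 otherwise) rather than the raw environment
reward, and the policy runs on CPU. run.py switches to the new PPO
checkpoint, uses greedy action selection (EPSILON = 0.0) and tree
observation depth 2 to match the training parameters.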
---
 reinforcement_learning/multi_agent_training.py | 2 +-
 reinforcement_learning/ppo_agent.py            | 4 +++-
 run.py                                         | 7 ++++---
 3 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index cc6cfce..3237ce7 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -172,7 +172,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, action_size, train_params)
-    if False:
+    if True:
         policy = PPOAgent(state_size, action_size)
     # Load existing policy
     if train_params.load_policy is not "":
diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py
index e97b265..703c956 100644
--- a/reinforcement_learning/ppo_agent.py
+++ b/reinforcement_learning/ppo_agent.py
@@ -9,7 +9,7 @@ from torch.distributions import Categorical
 # Hyperparameters
 from reinforcement_learning.policy import Policy
 
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+device = torch.device("cpu")#"cuda:0" if torch.cuda.is_available() else "cpu")
 print("device:", device)
 
 
@@ -145,8 +145,10 @@ class PPOAgent(Policy):
             if done_i:
                 discounted_reward = 0
                 done_list.insert(0, 1)
+                reward_i = 1
             else:
                 done_list.insert(0, 0)
+                reward_i = 0
             discounted_reward = reward_i + self.gamma * discounted_reward
             reward_list.insert(0, discounted_reward)
             state_next_list.insert(0, state_next_i)
diff --git a/run.py b/run.py
index 8a96775..667b7c4 100644
--- a/run.py
+++ b/run.py
@@ -47,7 +47,7 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 # Print per-step logs
 VERBOSE = True
 USE_FAST_TREEOBS = True
-USE_PPO_AGENT = False
+USE_PPO_AGENT = True
 
 # Checkpoint to use (remember to push it!)
 checkpoint = "./checkpoints/201124171810-7800.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10
@@ -57,15 +57,16 @@ checkpoint = "./checkpoints/201207144650-20000.pth" # PPO: 14.45790721540786
 checkpoint = "./checkpoints/201211063511-6300.pth" # DDDQN: 16.948349308440857
 checkpoint = "./checkpoints/201211095604-12000.pth" # DDDQN: 17.3862941316504
 checkpoint = "./checkpoints/201211164554-9400.pth" # DDDQN: 16.09241366013537
+checkpoint = "./checkpoints/201212190452-6500.pth" # PPO: 13.944402986414723
 
-EPSILON = 0.01
+EPSILON = 0.0
 
 # Use last action cache
 USE_ACTION_CACHE = False
 USE_DEAD_LOCK_AVOIDANCE_AGENT = False  # 21.54485505223213
 
 # Observation parameters (must match training parameters!)
-observation_tree_depth = 1
+observation_tree_depth = 2
 observation_radius = 10
 observation_max_path_depth = 30
 
-- 
GitLab