From f4fca1d59e9df6ccf1b21752ce33419d89753642 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Mon, 30 Nov 2020 15:11:27 +0100
Subject: [PATCH] enable PPO agent and update training defaults

Train with the PPO agent instead of the Double Dueling DQN policy, and
add a USE_PPO_AGENT switch to run.py so the submission runner builds
the matching agent before loading the checkpoint. Raise the default
training budget from 2000 to 10000 episodes, switch the training and
evaluation environments to Test_0, raise the default observation
max_depth from 1 to 2, and point run.py at the 201130083154-2000
checkpoint. Also replace the "is not" identity comparison on
load_policy with "!=".
---
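Note: the agent choice is still toggled in two places (the hard-coded
"if True:" in multi_agent_training.py and the module-level
USE_PPO_AGENT flag in run.py). A possible follow-up is to select the
agent through a single command-line option. Below is a minimal sketch
under that assumption; the --policy flag and the create_policy helper
are hypothetical names, not existing code in this repository, while
the two constructor calls are copied from run.py:

    from argparse import ArgumentParser, Namespace

    from reinforcement_learning.dddqn_policy import DDDQNPolicy
    from reinforcement_learning.ppo.ppo_agent import PPOAgent


    def create_policy(name, state_size, action_size, n_agents):
        # Pick the agent implementation from a CLI string instead of
        # editing source files to flip booleans.
        if name == "ppo":
            return PPOAgent(state_size, action_size, n_agents)
        # Evaluation-mode DDDQN, constructed as run.py does today.
        return DDDQNPolicy(state_size, action_size,
                           Namespace(**{'use_gpu': False}),
                           evaluation_mode=True)


    parser = ArgumentParser()
    parser.add_argument("--policy", choices=["dddqn", "ppo"], default="ppo",
                        help="agent implementation to load")
    args = parser.parse_args()
    # run.py would then build its policy once, with its own
    # state_size/action_size:
    # policy = create_policy(args.policy, state_size, action_size, 10)

In multi_agent_training.py the same option would replace the "if True:"
block, there passing train_params to DDDQNPolicy instead of the
evaluation Namespace.
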
 reinforcement_learning/multi_agent_training.py | 13 +++++++------
 run.py                                         |  9 +++++++--
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 06c35b2..6944eab 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -208,7 +208,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, action_size, train_params)
-    if False:
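+    # Hard-coded switch: train with the PPO agent instead of the DDDQN policy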
+    if True:
         policy = PPOAgent(state_size, action_size, n_agents)
     # Load existing policy
-    if train_params.load_policy is not "":
+    if train_params.load_policy != "":
@@ -546,10 +547,10 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=2000, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=10000, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=0,
                         type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
                         type=int)
     parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
@@ -573,7 +574,7 @@ if __name__ == "__main__":
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
     parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
                         action='store_true')
-    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
+    parser.add_argument("--max_depth", help="max depth", default=2, type=int)
 
     training_params = parser.parse_args()
     env_params = [
diff --git a/run.py b/run.py
index 33bc56c..a5e91a3 100644
--- a/run.py
+++ b/run.py
@@ -30,6 +30,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.evaluators.client import FlatlandRemoteClient
 from flatland.evaluators.client import TimeoutException
 
+from reinforcement_learning.ppo.ppo_agent import PPOAgent
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 from utils.deadlock_check import check_if_all_blocked
 from utils.fast_tree_obs import FastTreeObs
@@ -46,12 +47,14 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 # Print per-step logs
 VERBOSE = True
 USE_FAST_TREEOBS = True
+USE_PPO_AGENT = True
 
 # Checkpoint to use (remember to push it!)
 checkpoint = "./checkpoints/201124171810-7800.pth"  # 18.249244799876152 DEPTH=2 AGENTS=10
 # checkpoint = "./checkpoints/201126150143-5200.pth"  # 18.249244799876152 DEPTH=2 AGENTS=10
 # checkpoint = "./checkpoints/201126160144-2000.pth"  # 18.249244799876152 DEPTH=2 AGENTS=10
 checkpoint = "./checkpoints/201127160352-2000.pth"
+checkpoint = "./checkpoints/201130083154-2000.pth"
 
 EPSILON = 0.005
 
@@ -99,8 +102,10 @@ else:
 action_size = 5
 
 # Creates the policy. No GPU on evaluation server.
-policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
-# policy = PPOAgent(state_size, action_size, 10)
+if USE_PPO_AGENT:
+    policy = PPOAgent(state_size, action_size, 10)
+else:
+    policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
 policy.load(checkpoint)
 
 #####################################################################
-- 
GitLab