From f4fca1d59e9df6ccf1b21752ce33419d89753642 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Mon, 30 Nov 2020 15:11:27 +0100
Subject: [PATCH] clean up code - simplified

---
 reinforcement_learning/multi_agent_training.py | 10 +++++-----
 run.py                                         |  9 +++++++--
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 06c35b2..6944eab 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -208,7 +208,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, action_size, train_params)
-    if False:
+    if True:
         policy = PPOAgent(state_size, action_size, n_agents)
     # Load existing policy
     if train_params.load_policy is not "":
@@ -546,10 +546,10 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=2000, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=10000, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=0,
                         type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
                         type=int)
     parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
@@ -573,7 +573,7 @@ if __name__ == "__main__":
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
     parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
                         action='store_true')
-    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
+    parser.add_argument("--max_depth", help="max depth", default=2, type=int)
 
     training_params = parser.parse_args()
     env_params = [
diff --git a/run.py b/run.py
index 33bc56c..a5e91a3 100644
--- a/run.py
+++ b/run.py
@@ -30,6 +30,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.evaluators.client import FlatlandRemoteClient
 from flatland.evaluators.client import TimeoutException
 
+from reinforcement_learning.ppo.ppo_agent import PPOAgent
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 from utils.deadlock_check import check_if_all_blocked
 from utils.fast_tree_obs import FastTreeObs
@@ -46,12 +47,14 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 # Print per-step logs
 VERBOSE = True
 USE_FAST_TREEOBS = True
+USE_PPO_AGENT = True
 
 # Checkpoint to use (remember to push it!)
 checkpoint = "./checkpoints/201124171810-7800.pth" # 18.249244799876152 DEPTH=2 AGENTS=10
 # checkpoint = "./checkpoints/201126150143-5200.pth" # 18.249244799876152 DEPTH=2 AGENTS=10
 # checkpoint = "./checkpoints/201126160144-2000.pth" # 18.249244799876152 DEPTH=2 AGENTS=10
 checkpoint = "./checkpoints/201127160352-2000.pth"
+checkpoint = "./checkpoints/201130083154-2000.pth"
 
 EPSILON = 0.005
 
@@ -99,8 +102,10 @@ else:
     action_size = 5
 
 # Creates the policy. No GPU on evaluation server.
-policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
-# policy = PPOAgent(state_size, action_size, 10)
+if not USE_PPO_AGENT:
+    policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
+else:
+    policy = PPOAgent(state_size, action_size, 10)
 policy.load(checkpoint)
 
 #####################################################################
-- 
GitLab
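For reference, the evaluation-time policy switch this patch introduces in run.py boils down to the short sketch below. It is a minimal illustration, not the full run.py: it assumes it runs from the repository root so that the repo's DDDQNPolicy and PPOAgent modules import exactly as in the patch, and the state_size value is a placeholder (run.py derives it from the observation builder).

```python
# Minimal sketch of the policy selection added by this patch (run.py hunk @@ -99,8 +102,10 @@).
# Assumes the repository's modules are on the path, as they are when run.py executes.
from argparse import Namespace

from reinforcement_learning.dddqn_policy import DDDQNPolicy
from reinforcement_learning.ppo.ppo_agent import PPOAgent

USE_PPO_AGENT = True  # flag introduced by the patch
checkpoint = "./checkpoints/201130083154-2000.pth"  # newest checkpoint added by the patch

state_size = 26   # placeholder; run.py computes this from the observation builder
action_size = 5   # as in run.py

# Same branch as the patched run.py: DDDQN for evaluation unless the PPO agent is requested.
if not USE_PPO_AGENT:
    policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
else:
    policy = PPOAgent(state_size, action_size, 10)  # 10 = number of agents

policy.load(checkpoint)
```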