Commit f4fca1d5 authored by Egli Adrian (IT-SCI-API-PFI)

clean up code - simplified

parent c12f806e
@@ -208,7 +208,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, action_size, train_params)
-    if False:
+    if True:
         policy = PPOAgent(state_size, action_size, n_agents)
     # Load existing policy
     if train_params.load_policy is not "":
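Two details in this hunk are worth flagging: the `if False:` / `if True:` flip is a hand-edited switch, and the `is not ""` comparison in the last context line tests object identity rather than string equality (Python 3.8+ emits a SyntaxWarning for it). Below is a minimal sketch of the same selection logic using a flag and a value comparison, dropped into the hunk's own context; the `train_params.use_ppo` flag and the `policy.load(...)` body are assumptions for illustration, not part of this commit:

    # Sketch only: a flag-driven policy choice instead of editing `if True:` by hand.
    # `train_params.use_ppo` is a hypothetical argparse flag, not in this commit.
    policy = DDDQNPolicy(state_size, action_size, train_params)
    if train_params.use_ppo:
        policy = PPOAgent(state_size, action_size, n_agents)

    # `!=` compares string values; `is not ""` checks identity and is unreliable.
    if train_params.load_policy != "":
        # Assumed loader call, mirroring the evaluation script's policy.load(checkpoint).
        policy.load(train_params.load_policy)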
@@ -546,10 +546,10 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=2000, type=int)
+    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=10000, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=0,
                         type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
                         type=int)
     parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
@@ -573,7 +573,7 @@ if __name__ == "__main__":
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
     parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
                         action='store_true')
-    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
+    parser.add_argument("--max_depth", help="max depth", default=2, type=int)
     training_params = parser.parse_args()
     env_params = [
...
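Taken together, the argparse changes shift an out-of-the-box run from 2000 to 10000 episodes, point both the training and evaluation configs at Test_0, and deepen the observation tree from depth 1 to 2. A small self-contained check of what a bare invocation now resolves to (only the four changed arguments are reproduced here):

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=10000, type=int)
    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=0, type=int)
    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0, type=int)
    parser.add_argument("--max_depth", help="max depth", default=2, type=int)

    # Empty argv -> pure defaults, i.e. what a bare run of the training script now uses.
    args = parser.parse_args([])
    assert (args.n_episodes, args.training_env_config, args.evaluation_env_config, args.max_depth) == (10000, 0, 0, 2)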
@@ -30,6 +30,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.evaluators.client import FlatlandRemoteClient
 from flatland.evaluators.client import TimeoutException
+from reinforcement_learning.ppo.ppo_agent import PPOAgent
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 from utils.deadlock_check import check_if_all_blocked
 from utils.fast_tree_obs import FastTreeObs
@@ -46,12 +47,14 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 # Print per-step logs
 VERBOSE = True
 USE_FAST_TREEOBS = True
+USE_PPO_AGENT = True
 # Checkpoint to use (remember to push it!)
 checkpoint = "./checkpoints/201124171810-7800.pth"  # 18.249244799876152 DEPTH=2 AGENTS=10
 # checkpoint = "./checkpoints/201126150143-5200.pth"  # 18.249244799876152 DEPTH=2 AGENTS=10
 # checkpoint = "./checkpoints/201126160144-2000.pth"  # 18.249244799876152 DEPTH=2 AGENTS=10
 checkpoint = "./checkpoints/201127160352-2000.pth"
+checkpoint = "./checkpoints/201130083154-2000.pth"
 EPSILON = 0.005
@@ -99,8 +102,10 @@ else:
     action_size = 5
 # Creates the policy. No GPU on evaluation server.
-policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
-# policy = PPOAgent(state_size, action_size, 10)
+if not USE_PPO_AGENT:
+    policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
+else:
+    policy = PPOAgent(state_size, action_size, 10)
 policy.load(checkpoint)
 #####################################################################
...
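After this commit the evaluation script picks its policy from the USE_PPO_AGENT constant and loads the 201130083154-2000.pth checkpoint (the last assignment to checkpoint wins; the earlier assignments are dead code kept for reference). A sketch of that selection factored into a helper; the make_policy wrapper is a refactoring suggestion, not code from the commit, and it assumes only the DDDQNPolicy/PPOAgent constructors and .load() call exactly as shown in the diff:

    from argparse import Namespace

    from reinforcement_learning.dddqn_policy import DDDQNPolicy
    from reinforcement_learning.ppo.ppo_agent import PPOAgent

    USE_PPO_AGENT = True
    checkpoint = "./checkpoints/201130083154-2000.pth"

    def make_policy(state_size, action_size, n_agents=10):
        # No GPU on the evaluation server, so the DDDQN branch forces CPU.
        if not USE_PPO_AGENT:
            policy = DDDQNPolicy(state_size, action_size,
                                 Namespace(**{'use_gpu': False}),
                                 evaluation_mode=True)
        else:
            policy = PPOAgent(state_size, action_size, n_agents)
        policy.load(checkpoint)
        return policy

Calling make_policy(state_size, action_size) then replaces the module-level if/else while keeping the checkpoint and flag in one place.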