diff --git a/run.py b/run.py index 8d97053a8ec304d0e5696776e0dfcf9cd4148e91..af654ba8f8a918510c6bd99ac6b4199732ce2ed0 100644 --- a/run.py +++ b/run.py @@ -30,9 +30,11 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.evaluators.client import FlatlandRemoteClient from flatland.evaluators.client import TimeoutException -from reinforcement_learning.ppo_agent import PPOPolicy +from reinforcement_learning.dddqn_policy import DDDQNPolicy from reinforcement_learning.deadlockavoidance_with_decision_agent import DeadLockAvoidanceWithDecisionAgent -from utils.agent_action_config import get_action_size, map_actions +from reinforcement_learning.multi_decision_agent import MultiDecisionAgent +from reinforcement_learning.ppo_agent import PPOPolicy +from utils.agent_action_config import get_action_size, map_actions, set_action_size_full, set_action_size_reduced from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent from utils.deadlock_check import check_if_all_blocked from utils.fast_tree_obs import FastTreeObs @@ -41,8 +43,6 @@ from utils.observation_utils import normalize_observation base_dir = Path(__file__).resolve().parent.parent sys.path.append(str(base_dir)) -from reinforcement_learning.dddqn_policy import DDDQNPolicy - #################################################### # EVALUATION PARAMETERS @@ -52,18 +52,13 @@ USE_FAST_TREEOBS = True USE_PPO_AGENT = True # Checkpoint to use (remember to push it!) -checkpoint = "./checkpoints/201219090514-8600.pth" # -# checkpoint = "./checkpoints/201215212134-12000.pth" # -checkpoint = "./checkpoints/201220171629-12000.pth" # DDDQN - EPSILON: 0.0 - 13.940848323912533 -checkpoint = "./checkpoints/201220203236-12000.pth" # PPO - EPSILON: 0.0 - 13.660942453931114 -checkpoint = "./checkpoints/201220214325-12000.pth" # PPO - EPSILON: 0.0 - 13.463600936043 - -EPSILON = 0.0 +USE_ACTION_SIZE_FULL = False +load_policy = "DeadLockAvoidanceWithDecision" +checkpoint = "./checkpoints/210118115616-5000.pth" +EPSILON = 0.002 # Use last action cache USE_ACTION_CACHE = False -USE_DEAD_LOCK_AVOIDANCE_AGENT = False # 21.54485505223213 -USE_MULTI_DECISION_AGENT = False # Observation parameters (must match training parameters!) observation_tree_depth = 2 @@ -102,15 +97,6 @@ else: n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)]) state_size = n_features_per_node * n_nodes -action_size = get_action_size() - -# Creates the policy. No GPU on evaluation server. -if not USE_PPO_AGENT: - trained_policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True) -else: - trained_policy = PPOPolicy(state_size, action_size) -trained_policy.load(checkpoint) - ##################################################################### # Main evaluation loop ##################################################################### @@ -145,9 +131,31 @@ while True: tree_observation.set_env(local_env) tree_observation.reset() - policy = trained_policy - if USE_MULTI_DECISION_AGENT: - policy = DeadLockAvoidanceWithDecisionAgent(local_env, state_size, action_size, trained_policy) + # Creates the policy. No GPU on evaluation server. + if USE_ACTION_SIZE_FULL: + set_action_size_full() + else: + set_action_size_reduced() + + if load_policy == "DDDQN": + policy = DDDQNPolicy(state_size, get_action_size(), Namespace(**{'use_gpu': False}), evaluation_mode=True) + elif load_policy == "PPO": + policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, + in_parameters=Namespace(**{'use_gpu': False})) + elif load_policy == "DeadLockAvoidance": + policy = DeadLockAvoidanceAgent(local_env, get_action_size(), enable_eps=False) + elif load_policy == "DeadLockAvoidanceWithDecision": + # inter_policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params) + inter_policy = DDDQNPolicy(state_size, get_action_size(), Namespace(**{'use_gpu': False}), evaluation_mode=True) + policy = DeadLockAvoidanceWithDecisionAgent(local_env, state_size, get_action_size(), inter_policy) + elif load_policy == "MultiDecision": + policy = MultiDecisionAgent(state_size, get_action_size(), Namespace(**{'use_gpu': False})) + else: + policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, + in_parameters=Namespace(**{'use_gpu': False})) + + policy.load(checkpoint) + policy.reset(local_env) observation = tree_observation.get_many(list(range(nb_agents))) @@ -168,9 +176,6 @@ while True: agent_last_action = {} nb_hit = 0 - if USE_DEAD_LOCK_AVOIDANCE_AGENT: - policy = DeadLockAvoidanceAgent(local_env, action_size) - policy.start_episode(train=False) while True: try: @@ -185,14 +190,7 @@ while True: time_start = time.time() action_dict = {} policy.start_step(train=False) - if USE_DEAD_LOCK_AVOIDANCE_AGENT: - observation = np.zeros((local_env.get_num_agents(), 2)) for agent_handle in range(nb_agents): - - if USE_DEAD_LOCK_AVOIDANCE_AGENT: - observation[agent_handle][0] = agent_handle - observation[agent_handle][1] = steps - if info['action_required'][agent_handle]: if agent_handle in agent_last_obs and np.all( agent_last_obs[agent_handle] == observation[agent_handle]): diff --git a/runs/Jan14_10-56-32_K57261_PPO_reduced/events.out.tfevents.1610618195.K57261.15412.0 b/runs_bench/Jan14_10-56-32_K57261_PPO_reduced/events.out.tfevents.1610618195.K57261.15412.0 similarity index 100% rename from runs/Jan14_10-56-32_K57261_PPO_reduced/events.out.tfevents.1610618195.K57261.15412.0 rename to runs_bench/Jan14_10-56-32_K57261_PPO_reduced/events.out.tfevents.1610618195.K57261.15412.0 diff --git a/runs/Jan18_09-32-17_K57261_DDDQN_reduced/events.out.tfevents.1610958740.K57261.6608.0 b/runs_bench/Jan18_09-32-17_K57261_DDDQN_reduced/events.out.tfevents.1610958740.K57261.6608.0 similarity index 100% rename from runs/Jan18_09-32-17_K57261_DDDQN_reduced/events.out.tfevents.1610958740.K57261.6608.0 rename to runs_bench/Jan18_09-32-17_K57261_DDDQN_reduced/events.out.tfevents.1610958740.K57261.6608.0 diff --git a/runs/Jan18_09-34-10_K57261_DeadLockAvoidance_EPS_reduced/events.out.tfevents.1610958853.K57261.10660.0 b/runs_bench/Jan18_09-34-10_K57261_DeadLockAvoidance_EPS_reduced/events.out.tfevents.1610958853.K57261.10660.0 similarity index 100% rename from runs/Jan18_09-34-10_K57261_DeadLockAvoidance_EPS_reduced/events.out.tfevents.1610958853.K57261.10660.0 rename to runs_bench/Jan18_09-34-10_K57261_DeadLockAvoidance_EPS_reduced/events.out.tfevents.1610958853.K57261.10660.0 diff --git a/runs/Jan18_11-47-54_K57261_DeadLockAvoidance_reduced/events.out.tfevents.1610966876.K57261.4332.0 b/runs_bench/Jan18_11-47-54_K57261_DeadLockAvoidance_reduced/events.out.tfevents.1610966876.K57261.4332.0 similarity index 100% rename from runs/Jan18_11-47-54_K57261_DeadLockAvoidance_reduced/events.out.tfevents.1610966876.K57261.4332.0 rename to runs_bench/Jan18_11-47-54_K57261_DeadLockAvoidance_reduced/events.out.tfevents.1610966876.K57261.4332.0 diff --git a/runs/Jan18_11-56-16_K57261_DeadLockAvoidanceWithDecision_reduced/events.out.tfevents.1610967379.K57261.14680.0 b/runs_bench/Jan18_11-56-16_K57261_DeadLockAvoidanceWithDecision_reduced/events.out.tfevents.1610967379.K57261.14680.0 similarity index 100% rename from runs/Jan18_11-56-16_K57261_DeadLockAvoidanceWithDecision_reduced/events.out.tfevents.1610967379.K57261.14680.0 rename to runs_bench/Jan18_11-56-16_K57261_DeadLockAvoidanceWithDecision_reduced/events.out.tfevents.1610967379.K57261.14680.0 diff --git a/runs/Jan18_13-46-59_K57261_MultiDecisionAgent_reduced/events.out.tfevents.1610974021.K57261.12972.0 b/runs_bench/Jan18_13-46-59_K57261_MultiDecisionAgent_reduced/events.out.tfevents.1610974021.K57261.12972.0 similarity index 100% rename from runs/Jan18_13-46-59_K57261_MultiDecisionAgent_reduced/events.out.tfevents.1610974021.K57261.12972.0 rename to runs_bench/Jan18_13-46-59_K57261_MultiDecisionAgent_reduced/events.out.tfevents.1610974021.K57261.12972.0 diff --git a/runs/Jan18_14-53-57_K57261_PPO_full/events.out.tfevents.1610978039.K57261.484.0 b/runs_bench/Jan18_14-53-57_K57261_PPO_full/events.out.tfevents.1610978039.K57261.484.0 similarity index 100% rename from runs/Jan18_14-53-57_K57261_PPO_full/events.out.tfevents.1610978039.K57261.484.0 rename to runs_bench/Jan18_14-53-57_K57261_PPO_full/events.out.tfevents.1610978039.K57261.484.0 diff --git a/runs/Jan18_14-57-56_K57261_DDDQN_full/events.out.tfevents.1610978281.K57261.19984.0 b/runs_bench/Jan18_14-57-56_K57261_DDDQN_full/events.out.tfevents.1610978281.K57261.19984.0 similarity index 100% rename from runs/Jan18_14-57-56_K57261_DDDQN_full/events.out.tfevents.1610978281.K57261.19984.0 rename to runs_bench/Jan18_14-57-56_K57261_DDDQN_full/events.out.tfevents.1610978281.K57261.19984.0 diff --git a/runs/Jan18_16-05-23_K57261_DeadLockAvoidance_EPS_full/events.out.tfevents.1610982327.K57261.6264.0 b/runs_bench/Jan18_16-05-23_K57261_DeadLockAvoidance_EPS_full/events.out.tfevents.1610982327.K57261.6264.0 similarity index 100% rename from runs/Jan18_16-05-23_K57261_DeadLockAvoidance_EPS_full/events.out.tfevents.1610982327.K57261.6264.0 rename to runs_bench/Jan18_16-05-23_K57261_DeadLockAvoidance_EPS_full/events.out.tfevents.1610982327.K57261.6264.0 diff --git a/runs/Jan18_16-14-19_K57261_DeadLockAvoidance_full/events.out.tfevents.1610982862.K57261.14612.0 b/runs_bench/Jan18_16-14-19_K57261_DeadLockAvoidance_full/events.out.tfevents.1610982862.K57261.14612.0 similarity index 100% rename from runs/Jan18_16-14-19_K57261_DeadLockAvoidance_full/events.out.tfevents.1610982862.K57261.14612.0 rename to runs_bench/Jan18_16-14-19_K57261_DeadLockAvoidance_full/events.out.tfevents.1610982862.K57261.14612.0 diff --git a/runs/Jan18_16-43-41_K57261_DeadLockAvoidanceWithDecision_full/events.out.tfevents.1610984623.K57261.17628.0 b/runs_bench/Jan18_16-43-41_K57261_DeadLockAvoidanceWithDecision_full/events.out.tfevents.1610984623.K57261.17628.0 similarity index 100% rename from runs/Jan18_16-43-41_K57261_DeadLockAvoidanceWithDecision_full/events.out.tfevents.1610984623.K57261.17628.0 rename to runs_bench/Jan18_16-43-41_K57261_DeadLockAvoidanceWithDecision_full/events.out.tfevents.1610984623.K57261.17628.0 diff --git a/runs/Jan18_16-45-04_K57261_MultiDecision_full/events.out.tfevents.1610984709.K57261.1796.0 b/runs_bench/Jan18_16-45-04_K57261_MultiDecision_full/events.out.tfevents.1610984709.K57261.1796.0 similarity index 100% rename from runs/Jan18_16-45-04_K57261_MultiDecision_full/events.out.tfevents.1610984709.K57261.1796.0 rename to runs_bench/Jan18_16-45-04_K57261_MultiDecision_full/events.out.tfevents.1610984709.K57261.1796.0 diff --git a/utils/agent_action_config.py b/utils/agent_action_config.py index 9c2af58404e79b2b16430eed2cc71978420e987c..29750c9222965a044e4f447b7cdb8a97517b9953 100644 --- a/utils/agent_action_config.py +++ b/utils/agent_action_config.py @@ -11,16 +11,19 @@ def get_flatland_full_action_size(): def set_action_size_full(): + global _agent_action_config_action_size # The agents (DDDQN, PPO, ... ) have this actions space _agent_action_config_action_size = 5 def set_action_size_reduced(): + global _agent_action_config_action_size # The agents (DDDQN, PPO, ... ) have this actions space _agent_action_config_action_size = 4 def get_action_size(): + global _agent_action_config_action_size # The agents (DDDQN, PPO, ... ) have this actions space return _agent_action_config_action_size