From 97104dee80d5df776dac041f7286f833207caf4b Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Wed, 20 Jan 2021 17:13:13 +0100 Subject: [PATCH] clear --- run.py | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/run.py b/run.py index af654ba..57f8fdc 100644 --- a/run.py +++ b/run.py @@ -25,6 +25,7 @@ from pathlib import Path import numpy as np from flatland.core.env_observation_builder import DummyObservationBuilder +from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.evaluators.client import FlatlandRemoteClient @@ -34,7 +35,7 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy from reinforcement_learning.deadlockavoidance_with_decision_agent import DeadLockAvoidanceWithDecisionAgent from reinforcement_learning.multi_decision_agent import MultiDecisionAgent from reinforcement_learning.ppo_agent import PPOPolicy -from utils.agent_action_config import get_action_size, map_actions, set_action_size_full, set_action_size_reduced +from utils.agent_action_config import get_action_size, map_actions, set_action_size_reduced, set_action_size_full from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent from utils.deadlock_check import check_if_all_blocked from utils.fast_tree_obs import FastTreeObs @@ -45,6 +46,7 @@ sys.path.append(str(base_dir)) #################################################### # EVALUATION PARAMETERS +set_action_size_full() # Print per-step logs VERBOSE = True @@ -52,10 +54,24 @@ USE_FAST_TREEOBS = True USE_PPO_AGENT = True # Checkpoint to use (remember to push it!) 
-USE_ACTION_SIZE_FULL = False +set_action_size_reduced() load_policy = "DeadLockAvoidanceWithDecision" -checkpoint = "./checkpoints/210118115616-5000.pth" -EPSILON = 0.002 +checkpoint = "./checkpoints/210119075622-10000.pth" # 22.13346834815911 +EPSILON = 0.0 + +# Checkpoint to use (remember to push it!) +set_action_size_reduced() +load_policy = "PPO" +checkpoint = "./checkpoints/210119134958-10000.pth" # 12.18162927750207 +EPSILON = 0.0 + +# Checkpoint to use (remember to push it!) +set_action_size_reduced() +load_policy = "DDDQN" +checkpoint = "./checkpoints/210119171409-10000.pth" # 12.18162927750207 +EPSILON = 0.0 + +load_policy = "DeadLockAvoidance" # Use last action cache USE_ACTION_CACHE = False @@ -132,16 +148,10 @@ while True: tree_observation.reset() # Creates the policy. No GPU on evaluation server. - if USE_ACTION_SIZE_FULL: - set_action_size_full() - else: - set_action_size_reduced() - if load_policy == "DDDQN": policy = DDDQNPolicy(state_size, get_action_size(), Namespace(**{'use_gpu': False}), evaluation_mode=True) elif load_policy == "PPO": - policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, - in_parameters=Namespace(**{'use_gpu': False})) + policy = PPOPolicy(state_size, get_action_size()) elif load_policy == "DeadLockAvoidance": policy = DeadLockAvoidanceAgent(local_env, get_action_size(), enable_eps=False) elif load_policy == "DeadLockAvoidanceWithDecision": @@ -232,7 +242,11 @@ while True: step_time = time.time() - time_start time_taken_per_step.append(step_time) - nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles()) + nb_agents_done = 0 + for i_agent, agent in enumerate(local_env.agents): + # count agents that have finished, i.e. whose status is DONE or DONE_REMOVED + if (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]): + nb_agents_done += 1 if VERBOSE or done['__all__']: print( -- GitLab