Skip to content
Snippets Groups Projects
Commit 97104dee authored by Egli Adrian (IT-SCI-API-PFI)
Browse files

clear

parent e4443c95
No related branches found
No related tags found
No related merge requests found
...@@ -25,6 +25,7 @@ from pathlib import Path ...@@ -25,6 +25,7 @@ from pathlib import Path
import numpy as np import numpy as np
from flatland.core.env_observation_builder import DummyObservationBuilder from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.evaluators.client import FlatlandRemoteClient from flatland.evaluators.client import FlatlandRemoteClient
...@@ -34,7 +35,7 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy ...@@ -34,7 +35,7 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
from reinforcement_learning.deadlockavoidance_with_decision_agent import DeadLockAvoidanceWithDecisionAgent from reinforcement_learning.deadlockavoidance_with_decision_agent import DeadLockAvoidanceWithDecisionAgent
from reinforcement_learning.multi_decision_agent import MultiDecisionAgent from reinforcement_learning.multi_decision_agent import MultiDecisionAgent
from reinforcement_learning.ppo_agent import PPOPolicy from reinforcement_learning.ppo_agent import PPOPolicy
from utils.agent_action_config import get_action_size, map_actions, set_action_size_full, set_action_size_reduced from utils.agent_action_config import get_action_size, map_actions, set_action_size_reduced, set_action_size_full
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
from utils.deadlock_check import check_if_all_blocked from utils.deadlock_check import check_if_all_blocked
from utils.fast_tree_obs import FastTreeObs from utils.fast_tree_obs import FastTreeObs
...@@ -45,6 +46,7 @@ sys.path.append(str(base_dir)) ...@@ -45,6 +46,7 @@ sys.path.append(str(base_dir))
#################################################### ####################################################
# EVALUATION PARAMETERS # EVALUATION PARAMETERS
set_action_size_full()
# Print per-step logs # Print per-step logs
VERBOSE = True VERBOSE = True
...@@ -52,10 +54,24 @@ USE_FAST_TREEOBS = True ...@@ -52,10 +54,24 @@ USE_FAST_TREEOBS = True
USE_PPO_AGENT = True USE_PPO_AGENT = True
# Checkpoint to use (remember to push it!) # Checkpoint to use (remember to push it!)
USE_ACTION_SIZE_FULL = False set_action_size_reduced()
load_policy = "DeadLockAvoidanceWithDecision" load_policy = "DeadLockAvoidanceWithDecision"
checkpoint = "./checkpoints/210118115616-5000.pth" checkpoint = "./checkpoints/210119075622-10000.pth" # 22.13346834815911
EPSILON = 0.002 EPSILON = 0.0
# Checkpoint to use (remember to push it!)
set_action_size_reduced()
load_policy = "PPO"
checkpoint = "./checkpoints/210119134958-10000.pth" # 12.18162927750207
EPSILON = 0.0
# Checkpoint to use (remember to push it!)
set_action_size_reduced()
load_policy = "DDDQN"
checkpoint = "./checkpoints/210119171409-10000.pth" # 12.18162927750207
EPSILON = 0.0
load_policy = "DeadLockAvoidance"
# Use last action cache # Use last action cache
USE_ACTION_CACHE = False USE_ACTION_CACHE = False
...@@ -132,16 +148,10 @@ while True: ...@@ -132,16 +148,10 @@ while True:
tree_observation.reset() tree_observation.reset()
# Creates the policy. No GPU on evaluation server. # Creates the policy. No GPU on evaluation server.
if USE_ACTION_SIZE_FULL:
set_action_size_full()
else:
set_action_size_reduced()
if load_policy == "DDDQN": if load_policy == "DDDQN":
policy = DDDQNPolicy(state_size, get_action_size(), Namespace(**{'use_gpu': False}), evaluation_mode=True) policy = DDDQNPolicy(state_size, get_action_size(), Namespace(**{'use_gpu': False}), evaluation_mode=True)
elif load_policy == "PPO": elif load_policy == "PPO":
policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, policy = PPOPolicy(state_size, get_action_size())
in_parameters=Namespace(**{'use_gpu': False}))
elif load_policy == "DeadLockAvoidance": elif load_policy == "DeadLockAvoidance":
policy = DeadLockAvoidanceAgent(local_env, get_action_size(), enable_eps=False) policy = DeadLockAvoidanceAgent(local_env, get_action_size(), enable_eps=False)
elif load_policy == "DeadLockAvoidanceWithDecision": elif load_policy == "DeadLockAvoidanceWithDecision":
...@@ -232,7 +242,11 @@ while True: ...@@ -232,7 +242,11 @@ while True:
step_time = time.time() - time_start step_time = time.time() - time_start
time_taken_per_step.append(step_time) time_taken_per_step.append(step_time)
nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles()) nb_agents_done = 0
for i_agent, agent in enumerate(local_env.agents):
# count the agents that have reached their target (status DONE or DONE_REMOVED)
if (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]):
nb_agents_done += 1
if VERBOSE or done['__all__']: if VERBOSE or done['__all__']:
print( print(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment