From 97104dee80d5df776dac041f7286f833207caf4b Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Wed, 20 Jan 2021 17:13:13 +0100
Subject: [PATCH] clear

---
 run.py | 38 ++++++++++++++++++++++++++------------
 1 file changed, 26 insertions(+), 12 deletions(-)

diff --git a/run.py b/run.py
index af654ba..57f8fdc 100644
--- a/run.py
+++ b/run.py
@@ -25,6 +25,7 @@ from pathlib import Path
 
 import numpy as np
 from flatland.core.env_observation_builder import DummyObservationBuilder
+from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.evaluators.client import FlatlandRemoteClient
@@ -34,7 +35,7 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 from reinforcement_learning.deadlockavoidance_with_decision_agent import DeadLockAvoidanceWithDecisionAgent
 from reinforcement_learning.multi_decision_agent import MultiDecisionAgent
 from reinforcement_learning.ppo_agent import PPOPolicy
-from utils.agent_action_config import get_action_size, map_actions, set_action_size_full, set_action_size_reduced
+from utils.agent_action_config import get_action_size, map_actions, set_action_size_reduced, set_action_size_full
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 from utils.deadlock_check import check_if_all_blocked
 from utils.fast_tree_obs import FastTreeObs
@@ -45,6 +46,7 @@ sys.path.append(str(base_dir))
 
 ####################################################
 # EVALUATION PARAMETERS
+set_action_size_full()
 
 # Print per-step logs
 VERBOSE = True
@@ -52,10 +54,24 @@ USE_FAST_TREEOBS = True
 USE_PPO_AGENT = True
 
 # Checkpoint to use (remember to push it!)
-USE_ACTION_SIZE_FULL = False
+set_action_size_reduced()
 load_policy = "DeadLockAvoidanceWithDecision"
-checkpoint = "./checkpoints/210118115616-5000.pth"
-EPSILON = 0.002
+checkpoint = "./checkpoints/210119075622-10000.pth"  # 22.13346834815911
+EPSILON = 0.0
+
+# Checkpoint to use (remember to push it!)
+set_action_size_reduced()
+load_policy = "PPO"
+checkpoint = "./checkpoints/210119134958-10000.pth"  # 12.18162927750207
+EPSILON = 0.0
+
+# Checkpoint to use (remember to push it!)
+set_action_size_reduced()
+load_policy = "DDDQN"
+checkpoint = "./checkpoints/210119171409-10000.pth"  # 12.18162927750207
+EPSILON = 0.0
+
+load_policy = "DeadLockAvoidance"
 
 # Use last action cache
 USE_ACTION_CACHE = False
@@ -132,16 +148,10 @@ while True:
     tree_observation.reset()
 
     # Creates the policy. No GPU on evaluation server.
-    if USE_ACTION_SIZE_FULL:
-        set_action_size_full()
-    else:
-        set_action_size_reduced()
-
     if load_policy == "DDDQN":
         policy = DDDQNPolicy(state_size, get_action_size(), Namespace(**{'use_gpu': False}), evaluation_mode=True)
     elif load_policy == "PPO":
-        policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False,
-                           in_parameters=Namespace(**{'use_gpu': False}))
+        policy = PPOPolicy(state_size, get_action_size())
     elif load_policy == "DeadLockAvoidance":
         policy = DeadLockAvoidanceAgent(local_env, get_action_size(), enable_eps=False)
     elif load_policy == "DeadLockAvoidanceWithDecision":
@@ -232,7 +242,11 @@ while True:
                 step_time = time.time() - time_start
                 time_taken_per_step.append(step_time)
 
-            nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles())
+            nb_agents_done = 0
+            for i_agent, agent in enumerate(local_env.agents):
+                # count agents that have finished (status DONE or DONE_REMOVED)
+                if (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]):
+                    nb_agents_done += 1
 
             if VERBOSE or done['__all__']:
                 print(
-- 
GitLab