From ed050703921efdab033a5898d2993be2a87a0e8c Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Sat, 19 Dec 2020 13:01:22 +0100
Subject: [PATCH] Refactor policy hierarchy: add HeuristicPolicy, LearningPolicy and HybridPolicy

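Introduce three Policy subclasses in reinforcement_learning/policy.py and
reparent the existing agents onto them:

- DDDQNPolicy and PPOPolicy now derive from LearningPolicy.
- DeadLockAvoidanceAgent now derives from HeuristicPolicy.
- MultiDecisionAgent now derives from HybridPolicy and delegates to the
  learning agent only when an agent is on a switch (previously also when it
  was near a switch).

Training defaults in multi_agent_training.py change to
training_env_config=2, evaluation_env_config=2 and n_evaluation_episodes=10,
and ActorCriticModel.load now reads the critic network from the ".value"
file suffix.

A minimal sketch of one possible use of the new hierarchy; is_trainable is a
hypothetical helper for illustration only and is not part of this patch:

    from reinforcement_learning.policy import HybridPolicy, LearningPolicy, Policy

    def is_trainable(policy: Policy) -> bool:
        # Hypothetical helper (assumption): learning and hybrid policies wrap
        # trainable parameters, pure heuristics such as DeadLockAvoidanceAgent
        # do not.
        return isinstance(policy, (LearningPolicy, HybridPolicy))
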
---
 reinforcement_learning/dddqn_policy.py               |  4 ++--
 reinforcement_learning/multi_agent_training.py       |  6 +++---
 reinforcement_learning/policy.py                     | 12 ++++++++++++
 reinforcement_learning/ppo_agent.py                  |  6 +++---
 .../ppo_deadlockavoidance_agent.py                   |  6 +++---
 utils/dead_lock_avoidance_agent.py                   |  4 ++--
 6 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 9864ca6..864c6a7 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -9,11 +9,11 @@ import torch.nn.functional as F
 import torch.optim as optim
 
 from reinforcement_learning.model import DuelingQNetwork
-from reinforcement_learning.policy import Policy
+from reinforcement_learning.policy import Policy, LearningPolicy
 from reinforcement_learning.replay_buffer import ReplayBuffer
 
 
-class DDDQNPolicy(Policy):
+class DDDQNPolicy(LearningPolicy):
     """Dueling Double DQN policy"""
 
     def __init__(self, state_size, action_size, in_parameters, evaluation_mode=False):
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index cce7ecc..b219ace 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -519,11 +519,11 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=12000, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=3,
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2,
                         type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=2,
                         type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=1, type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=10, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
     parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.005, type=float)
diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
index 5b118ae..9b883d1 100644
--- a/reinforcement_learning/policy.py
+++ b/reinforcement_learning/policy.py
@@ -37,3 +37,15 @@ class Policy:
 
     def clone(self):
         return self
+
+class HeuristicPolicy(Policy):
+    def __init__(self):
+        super(HeuristicPolicy, self).__init__()
+
+class LearningPolicy(Policy):
+    def __init__(self):
+        super(LearningPolicy, self).__init__()
+
+class HybridPolicy(Policy):
+    def __init__(self):
+        super(HybridPolicy, self).__init__()
diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py
index 5c0fc08..51f0f71 100644
--- a/reinforcement_learning/ppo_agent.py
+++ b/reinforcement_learning/ppo_agent.py
@@ -8,7 +8,7 @@ import torch.optim as optim
 from torch.distributions import Categorical
 
 # Hyperparameters
-from reinforcement_learning.policy import Policy
+from reinforcement_learning.policy import LearningPolicy
 from reinforcement_learning.replay_buffer import ReplayBuffer
 
 device = torch.device("cpu")  # "cuda:0" if torch.cuda.is_available() else "cpu")
@@ -92,10 +92,10 @@ class ActorCriticModel(nn.Module):
     def load(self, filename):
         print("load policy from file", filename)
         self.actor = self._load(self.actor, filename + ".actor")
-        self.critic = self._load(self.critic, filename + ".critic")
+        self.critic = self._load(self.critic, filename + ".value")
 
 
-class PPOPolicy(Policy):
+class PPOPolicy(LearningPolicy):
     def __init__(self, state_size, action_size):
         print(">> PPOPolicy")
         super(PPOPolicy, self).__init__()
diff --git a/reinforcement_learning/ppo_deadlockavoidance_agent.py b/reinforcement_learning/ppo_deadlockavoidance_agent.py
index 6e8880c..f891748 100644
--- a/reinforcement_learning/ppo_deadlockavoidance_agent.py
+++ b/reinforcement_learning/ppo_deadlockavoidance_agent.py
@@ -1,13 +1,13 @@
 from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 
-from reinforcement_learning.policy import Policy
+from reinforcement_learning.policy import HybridPolicy
 from utils.agent_action_config import map_rail_env_action
 from utils.agent_can_choose_helper import AgentCanChooseHelper
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 
 
-class MultiDecisionAgent(Policy):
+class MultiDecisionAgent(HybridPolicy):
 
     def __init__(self, env: RailEnv, state_size, action_size, learning_agent):
         self.env = env
@@ -33,7 +33,7 @@ class MultiDecisionAgent(Policy):
         if agent.status < RailAgentStatus.DONE:
             agents_on_switch, agents_near_to_switch, _, _ = \
                 self.agent_can_choose_helper.check_agent_decision(position, direction)
-            if agents_on_switch or agents_near_to_switch:
+            if agents_on_switch:
                 return self.learning_agent.act(handle, state, eps)
             else:
                 act = self.dead_lock_avoidance_agent.act(handle, state, -1.0)
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index cad3b74..1f0030c 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -6,7 +6,7 @@ from flatland.core.env_observation_builder import DummyObservationBuilder
 from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
 
-from reinforcement_learning.policy import Policy
+from reinforcement_learning.policy import HeuristicPolicy
 from utils.agent_action_config import map_rail_env_action
 from utils.shortest_distance_walker import ShortestDistanceWalker
 
@@ -75,7 +75,7 @@ class DummyMemory:
         return 0
 
 
-class DeadLockAvoidanceAgent(Policy):
+class DeadLockAvoidanceAgent(HeuristicPolicy):
     def __init__(self, env: RailEnv, action_size, show_debug_plot=False):
         self.env = env
         self.memory = DummyMemory()
-- 
GitLab