From 0a273a9447d151e16f0fec66877364f3104f90f0 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Tue, 22 Dec 2020 09:40:09 +0100
Subject: [PATCH] Policy updated

---
 ... deadlockavoidance_with_decision_agent.py} |  2 +-
 .../multi_agent_training.py                   | 23 ++---
 .../multi_decision_agent.py                   | 90 +++++++++++++++++++
 reinforcement_learning/policy.py              | 11 +++
 run.py                                        |  4 +-
 utils/dead_lock_avoidance_agent.py            |  5 --
 utils/deadlock_check.py                       |  9 ++
 7 files changed, 126 insertions(+), 18 deletions(-)
 rename reinforcement_learning/{ppo_deadlockavoidance_agent.py => deadlockavoidance_with_decision_agent.py} (98%)
 create mode 100644 reinforcement_learning/multi_decision_agent.py

diff --git a/reinforcement_learning/ppo_deadlockavoidance_agent.py b/reinforcement_learning/deadlockavoidance_with_decision_agent.py
similarity index 98%
rename from reinforcement_learning/ppo_deadlockavoidance_agent.py
rename to reinforcement_learning/deadlockavoidance_with_decision_agent.py
index f891748..a1726f8 100644
--- a/reinforcement_learning/ppo_deadlockavoidance_agent.py
+++ b/reinforcement_learning/deadlockavoidance_with_decision_agent.py
@@ -7,7 +7,7 @@ from utils.agent_can_choose_helper import AgentCanChooseHelper
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 
 
-class MultiDecisionAgent(HybridPolicy):
+class DeadLockAvoidanceWithDecisionAgent(HybridPolicy):
 
     def __init__(self, env: RailEnv, state_size, action_size, learning_agent):
         self.env = env
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 4c06133..2e22854 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -9,11 +9,10 @@ from pprint import pprint
 
 import numpy as np
 import psutil
-from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
-from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
+from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
@@ -21,11 +20,10 @@ from torch.utils.tensorboard import SummaryWriter
 
 from reinforcement_learning.dddqn_policy import DDDQNPolicy
 from reinforcement_learning.ppo_agent import PPOPolicy
-from reinforcement_learning.ppo_deadlockavoidance_agent import MultiDecisionAgent
-from utils.agent_action_config import get_flatland_full_action_size, get_action_size, map_actions, map_action, \
-    map_rail_env_action
+from reinforcement_learning.deadlockavoidance_with_decision_agent import DeadLockAvoidanceWithDecisionAgent
+from reinforcement_learning.multi_decision_agent import MultiDecisionAgent
+from utils.agent_action_config import get_flatland_full_action_size, get_action_size, map_actions, map_action
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
-from utils.deadlock_check import get_agent_positions, check_for_deadlock
 
 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
@@ -172,13 +170,18 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     completion_window = deque(maxlen=checkpoint_interval)
 
     # Double Dueling DQN policy
-    policy = DDDQNPolicy(state_size, get_action_size(), train_params)
-    if True:
+    policy = None
+    if False:
+        policy = DDDQNPolicy(state_size, get_action_size(), train_params)
+    if False:
         policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
     if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
     if False:
-        policy = MultiDecisionAgent(train_env, state_size, get_action_size(), policy)
+        inter_policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
+        policy = DeadLockAvoidanceWithDecisionAgent(train_env, state_size, get_action_size(), inter_policy)
+    if True:
+        policy = MultiDecisionAgent(state_size, get_action_size(), train_params)
 
     # Load existing policy
     if train_params.load_policy is not "":
@@ -227,7 +230,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
         # Reset environment
         reset_timer.start()
-        number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 500)))
+        number_of_agents = n_agents  # int(min(n_agents, 1 + np.floor(episode_idx / 500)))
         train_env_params.n_agents = episode_idx % number_of_agents + 1
 
         train_env = create_rail_env(train_env_params, tree_observation)
diff --git a/reinforcement_learning/multi_decision_agent.py b/reinforcement_learning/multi_decision_agent.py
new file mode 100644
index 0000000..13d1874
--- /dev/null
+++ b/reinforcement_learning/multi_decision_agent.py
@@ -0,0 +1,90 @@
+from flatland.envs.rail_env import RailEnv
+
+from reinforcement_learning.dddqn_policy import DDDQNPolicy
+from reinforcement_learning.policy import LearningPolicy, DummyMemory
+from reinforcement_learning.ppo_agent import PPOPolicy
+
+
+class MultiDecisionAgent(LearningPolicy):
+
+    def __init__(self, state_size, action_size, in_parameters=None):
+        print(">> MultiDecisionAgent")
+        super(MultiDecisionAgent, self).__init__()
+        self.state_size = state_size
+        self.action_size = action_size
+        self.in_parameters = in_parameters
+        self.memory = DummyMemory()
+        self.loss = 0
+
+        self.ppo_policy = PPOPolicy(state_size, action_size, use_replay_buffer=False, in_parameters=in_parameters)
+        self.dddqn_policy = DDDQNPolicy(state_size, action_size, in_parameters)
+        self.policy_selector = PPOPolicy(state_size, 2)
+
+
+    def step(self, handle, state, action, reward, next_state, done):
+        select = self.policy_selector.act(handle, state, 0.0)
+        self.ppo_policy.step(handle, state, action, reward, next_state, done)
+        self.dddqn_policy.step(handle, state, action, reward, next_state, done)
+        self.policy_selector.step(handle, state, select, reward, next_state, done)
+
+    def act(self, handle, state, eps=0.):
+        select = self.policy_selector.act(handle, state, eps)
+        if select == 0:
+            return self.dddqn_policy.act(handle, state, eps)
+        return self.policy_selector.act(handle, state, eps)
+
+    def save(self, filename):
+        self.ppo_policy.save(filename)
+        self.dddqn_policy.save(filename)
+        self.policy_selector.save(filename)
+
+    def load(self, filename):
+        self.ppo_policy.load(filename)
+        self.dddqn_policy.load(filename)
+        self.policy_selector.load(filename)
+
+    def start_step(self, train):
+        self.ppo_policy.start_step(train)
+        self.dddqn_policy.start_step(train)
+        self.policy_selector.start_step(train)
+
+    def end_step(self, train):
+        self.ppo_policy.end_step(train)
+        self.dddqn_policy.end_step(train)
+        self.policy_selector.end_step(train)
+
+    def start_episode(self, train):
+        self.ppo_policy.start_episode(train)
+        self.dddqn_policy.start_episode(train)
+        self.policy_selector.start_episode(train)
+
+    def end_episode(self, train):
+        self.ppo_policy.end_episode(train)
+        self.dddqn_policy.end_episode(train)
+        self.policy_selector.end_episode(train)
+
+    def load_replay_buffer(self, filename):
+        self.ppo_policy.load_replay_buffer(filename)
+        self.dddqn_policy.load_replay_buffer(filename)
+        self.policy_selector.load_replay_buffer(filename)
+
+    def test(self):
+        self.ppo_policy.test()
+        self.dddqn_policy.test()
+        self.policy_selector.test()
+
+    def reset(self, env: RailEnv):
+        self.ppo_policy.reset(env)
+        self.dddqn_policy.reset(env)
+        self.policy_selector.reset(env)
+
+    def clone(self):
+        multi_descision_agent = MultiDecisionAgent(
+            self.state_size,
+            self.action_size,
+            self.in_parameters
+        )
+        multi_descision_agent.ppo_policy = self.ppo_policy.clone()
+        multi_descision_agent.dddqn_policy = self.dddqn_policy.clone()
+        multi_descision_agent.policy_selector = self.policy_selector.clone()
+        return multi_descision_agent
diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
index 9b883d1..fe28cbc 100644
--- a/reinforcement_learning/policy.py
+++ b/reinforcement_learning/policy.py
@@ -1,6 +1,14 @@
 from flatland.envs.rail_env import RailEnv
 
 
+class DummyMemory:
+    def __init__(self):
+        self.memory = []
+
+    def __len__(self):
+        return 0
+
+
 class Policy:
     def step(self, handle, state, action, reward, next_state, done):
         raise NotImplementedError
@@ -38,14 +46,17 @@ class Policy:
     def clone(self):
         return self
 
+
 class HeuristicPolicy(Policy):
     def __init__(self):
         super(HeuristicPolicy).__init__()
 
+
 class LearningPolicy(Policy):
     def __init__(self):
         super(LearningPolicy).__init__()
 
+
 class HybridPolicy(Policy):
     def __init__(self):
         super(HybridPolicy).__init__()
diff --git a/run.py b/run.py
index 6578fee..8d97053 100644
--- a/run.py
+++ b/run.py
@@ -31,7 +31,7 @@ from flatland.evaluators.client import FlatlandRemoteClient
 from flatland.evaluators.client import TimeoutException
 
 from reinforcement_learning.ppo_agent import PPOPolicy
-from reinforcement_learning.ppo_deadlockavoidance_agent import MultiDecisionAgent
+from reinforcement_learning.deadlockavoidance_with_decision_agent import DeadLockAvoidanceWithDecisionAgent
 from utils.agent_action_config import get_action_size, map_actions
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 from utils.deadlock_check import check_if_all_blocked
@@ -147,7 +147,7 @@ while True:
 
     policy = trained_policy
     if USE_MULTI_DECISION_AGENT:
-        policy = MultiDecisionAgent(local_env, state_size, action_size, trained_policy)
+        policy = DeadLockAvoidanceWithDecisionAgent(local_env, state_size, action_size, trained_policy)
     policy.reset(local_env)
     observation = tree_observation.get_many(list(range(nb_agents)))
 
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index 1f0030c..ed3a3f7 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -67,12 +67,7 @@ class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker):
         self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1
 
 
-class DummyMemory:
-    def __init__(self):
-        self.memory = []
 
-    def __len__(self):
-        return 0
 
 
 class DeadLockAvoidanceAgent(HeuristicPolicy):
diff --git a/utils/deadlock_check.py b/utils/deadlock_check.py
index d787c8c..4df6731 100644
--- a/utils/deadlock_check.py
+++ b/utils/deadlock_check.py
@@ -17,6 +17,15 @@ def get_agent_positions(env):
     return agent_positions
 
 
+def get_agent_targets(env):
+    agent_targets = []
+    for agent_handle in env.get_agent_handles():
+        agent = env.agents[agent_handle]
+        if agent.status == RailAgentStatus.ACTIVE:
+            agent_targets.append(agent.target)
+    return agent_targets
+
+
 def check_for_deadlock(handle, env, agent_positions, check_position=None, check_direction=None):
     agent = env.agents[handle]
     if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
-- 
GitLab