From e28c57e5a9c17aba9415d3f80d7cb5e6a998d9c9 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Tue, 5 Jan 2021 10:25:35 +0100
Subject: [PATCH] Learn the policy selection with PPO; trim FastTreeObs

---
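The hybrid agent now lets a small learned selector (a PPOPolicy with two actions) decide per agent and per step whether the learning policy or the DeadLockAvoidanceAgent produces the action, replacing the rule-based AgentCanChooseHelper switch detection. Below is a minimal sketch of that delegation pattern only; the RandomSelector stand-in and the simplified signatures are assumptions for illustration, not the repository's PPOPolicy/DDDQNPolicy API.

    import random

    class RandomSelector:
        """Stand-in for the learned two-action selector (PPOPolicy(state_size, 2) in this patch)."""

        def act(self, handle, state, eps=0.0):
            # 0 -> delegate to the learning agent, 1 -> delegate to the deadlock-avoidance agent
            return random.randint(0, 1)

        def step(self, handle, state, action, reward, next_state, done):
            # A trained selector would update its policy from (state, action, reward) here.
            pass

    class HybridAgentSketch:
        """Simplified version of the delegation used by DeadLockAvoidanceWithDecisionAgent."""

        def __init__(self, learning_agent, dead_lock_avoidance_agent, selector):
            self.learning_agent = learning_agent
            self.dead_lock_avoidance_agent = dead_lock_avoidance_agent
            self.policy_selector = selector

        def act(self, handle, state, eps=0.0):
            # The selector picks which sub-policy answers for this agent at this step.
            if self.policy_selector.act(handle, state, eps) == 0:
                return self.learning_agent.act(handle, state, eps)
            return self.dead_lock_avoidance_agent.act(handle, state, -1.0)

        def step(self, handle, state, action, reward, next_state, done):
            # The selector is trained on the shared reward signal; both sub-policies
            # also observe every transition.
            select = self.policy_selector.act(handle, state, 0.0)
            self.policy_selector.step(handle, state, select, reward, next_state, done)
            self.dead_lock_avoidance_agent.step(handle, state, action, reward, next_state, done)
            self.learning_agent.step(handle, state, action, reward, next_state, done)

In the patched class every lifecycle hook (start_step/end_step, start_episode/end_episode, save/load, test, reset) is forwarded to the selector as well, and the selector's weights are persisted next to the main checkpoint under '<filename>.selector'.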
 .../deadlockavoidance_with_decision_agent.py  | 36 +++++-----
 .../multi_agent_training.py                   |  3 +-
 .../multi_decision_agent.py                   |  2 +-
 utils/fast_tree_obs.py                        | 72 ++++++++-----------
 4 files changed, 51 insertions(+), 62 deletions(-)
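multi_agent_training.py now ramps the agent count instead of always training with the configured maximum: the cap grows by one every 200 episodes and the per-episode agent count cycles below that cap. A minimal arithmetic sketch of the schedule (same two expressions as in the diff; the helper name is ad hoc):

    import numpy as np

    def agents_for_episode(episode_idx, n_agents):
        # The cap grows by one every 200 episodes, never exceeding the configured n_agents.
        number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 200)))
        # The episode then runs with 1..number_of_agents agents, cycling deterministically.
        return episode_idx % number_of_agents + 1

    # With n_agents = 10: episodes 0-199 train a single agent, episodes 200-399 alternate
    # between 1 and 2 agents, episodes 400-599 cycle through 1..3 agents, and so on.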

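utils/fast_tree_obs.py drops the embedded DeadLockAvoidanceAgent together with the deadlock-check features, and _explore() additionally reports a has_opp_target flag per explored branch, so the vector shrinks from 41 to 30 entries. The sketch below is read off the indices in the diff; the short descriptions are paraphrases, not the project's documentation, and how many of the relative branch directions are actually populated depends on the cell's transitions.

    # Indices 4-10 keep the existing per-agent switch/decision flags (untouched by this patch).
    # Per-branch features are written at base_index + dir_loop:
    FAST_TREE_OBS_BRANCH_FEATURES = {
        0:  "stepping into the branch's first cell shortens the distance to the target",
        11: "best distance found by _explore beats the current cell distance",
        15: "has_opp_agent found on the explored branch",
        19: "has_same_agent found on the explored branch",
        23: "has_target (the agent's own target) found on the explored branch",
        27: "has_opp_target found on the explored branch (new in this patch)",
    }
    # Any remaining infinite entries are set to -1 at the end of get().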
diff --git a/reinforcement_learning/deadlockavoidance_with_decision_agent.py b/reinforcement_learning/deadlockavoidance_with_decision_agent.py
index e9a6f8e..550e73e 100644
--- a/reinforcement_learning/deadlockavoidance_with_decision_agent.py
+++ b/reinforcement_learning/deadlockavoidance_with_decision_agent.py
@@ -2,8 +2,8 @@ from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 
 from reinforcement_learning.policy import HybridPolicy
+from reinforcement_learning.ppo_agent import PPOPolicy
 from utils.agent_action_config import map_rail_env_action
-from utils.agent_can_choose_helper import AgentCanChooseHelper
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 
 
@@ -17,69 +17,69 @@ class DeadLockAvoidanceWithDecisionAgent(HybridPolicy):
         self.action_size = action_size
         self.learning_agent = learning_agent
         self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, action_size, False)
-        self.agent_can_choose_helper = AgentCanChooseHelper()
+        self.policy_selector = PPOPolicy(state_size, 2)
+
         self.memory = self.learning_agent.memory
         self.loss = self.learning_agent.loss
 
     def step(self, handle, state, action, reward, next_state, done):
+        select = self.policy_selector.act(handle, state, 0.0)
+        self.policy_selector.step(handle, state, select, reward, next_state, done)
         self.dead_lock_avoidance_agent.step(handle, state, action, reward, next_state, done)
         self.learning_agent.step(handle, state, action, reward, next_state, done)
         self.loss = self.learning_agent.loss
 
     def act(self, handle, state, eps=0.):
-        agent = self.env.agents[handle]
-        position = agent.position
-        if position is None:
-            position = agent.initial_position
-        direction = agent.direction
-        if agent.status < RailAgentStatus.DONE:
-            agents_on_switch, agents_near_to_switch, _, _ = \
-                self.agent_can_choose_helper.check_agent_decision(position, direction)
-            if agents_on_switch or agents_near_to_switch:
-                return self.learning_agent.act(handle, state, eps)
-            else:
-                act = self.dead_lock_avoidance_agent.act(handle, state, -1.0)
-                return map_rail_env_action(act)
-        # Agent is still at target cell
-        return map_rail_env_action(RailEnvActions.DO_NOTHING)
+        select = self.policy_selector.act(handle, state, eps)
+        if select == 0:
+            return self.learning_agent.act(handle, state, eps)
+        return self.dead_lock_avoidance_agent.act(handle, state, -1.0)
 
     def save(self, filename):
         self.dead_lock_avoidance_agent.save(filename)
         self.learning_agent.save(filename)
+        self.policy_selector.save(filename + '.selector')
 
     def load(self, filename):
         self.dead_lock_avoidance_agent.load(filename)
         self.learning_agent.load(filename)
+        self.policy_selector.load(filename + '.selector')
 
     def start_step(self, train):
         self.dead_lock_avoidance_agent.start_step(train)
         self.learning_agent.start_step(train)
+        self.policy_selector.start_step(train)
 
     def end_step(self, train):
         self.dead_lock_avoidance_agent.end_step(train)
         self.learning_agent.end_step(train)
+        self.policy_selector.end_step(train)
 
     def start_episode(self, train):
         self.dead_lock_avoidance_agent.start_episode(train)
         self.learning_agent.start_episode(train)
+        self.policy_selector.start_episode(train)
 
     def end_episode(self, train):
         self.dead_lock_avoidance_agent.end_episode(train)
         self.learning_agent.end_episode(train)
+        self.policy_selector.end_episode(train)
 
     def load_replay_buffer(self, filename):
         self.dead_lock_avoidance_agent.load_replay_buffer(filename)
         self.learning_agent.load_replay_buffer(filename)
+        self.policy_selector.load_replay_buffer(filename + ".selector")
 
     def test(self):
         self.dead_lock_avoidance_agent.test()
         self.learning_agent.test()
+        self.policy_selector.test()
 
     def reset(self, env: RailEnv):
         self.env = env
-        self.agent_can_choose_helper.build_data(env)
         self.dead_lock_avoidance_agent.reset(env)
         self.learning_agent.reset(env)
+        self.policy_selector.reset(env)
 
     def clone(self):
         return self
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 6734f9c..882cf86 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -178,6 +178,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
     if True:
+        # inter_policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
         inter_policy = DDDQNPolicy(state_size, get_action_size(), train_params)
         policy = DeadLockAvoidanceWithDecisionAgent(train_env, state_size, get_action_size(), inter_policy)
     if False:
@@ -234,7 +235,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
         # Reset environment
         reset_timer.start()
-        number_of_agents = n_agents  # int(min(n_agents, 1 + np.floor(episode_idx / 500)))
+        number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 200)))
         train_env_params.n_agents = episode_idx % number_of_agents + 1
 
         train_env = create_rail_env(train_env_params, tree_observation)
diff --git a/reinforcement_learning/multi_decision_agent.py b/reinforcement_learning/multi_decision_agent.py
index 13d1874..5047bcd 100644
--- a/reinforcement_learning/multi_decision_agent.py
+++ b/reinforcement_learning/multi_decision_agent.py
@@ -22,9 +22,9 @@ class MultiDecisionAgent(LearningPolicy):
 
 
     def step(self, handle, state, action, reward, next_state, done):
-        select = self.policy_selector.act(handle, state, 0.0)
         self.ppo_policy.step(handle, state, action, reward, next_state, done)
         self.dddqn_policy.step(handle, state, action, reward, next_state, done)
+        select = self.policy_selector.act(handle, state, 0.0)
         self.policy_selector.step(handle, state, select, reward, next_state, done)
 
     def act(self, handle, state, eps=0.):
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index 8172703..f0b6277 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -4,11 +4,10 @@ import numpy as np
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.agent_utils import RailAgentStatus
-from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions
+from flatland.envs.rail_env import fast_count_nonzero, fast_argmax
 
 from utils.agent_can_choose_helper import AgentCanChooseHelper
-from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
-from utils.deadlock_check import check_for_deadlock, get_agent_positions
+from utils.deadlock_check import get_agent_positions, get_agent_targets
 
 """
 LICENCE for the FastTreeObs Observation Builder  
@@ -27,7 +26,7 @@ class FastTreeObs(ObservationBuilder):
 
     def __init__(self, max_depth: Any):
         self.max_depth = max_depth
-        self.observation_dim = 41
+        self.observation_dim = 30
         self.agent_can_choose_helper = None
 
     def debug_render(self, env_renderer):
@@ -65,21 +64,18 @@ class FastTreeObs(ObservationBuilder):
         self.agent_can_choose_helper.build_data(self.env)
         self.debug_render_list = []
         self.debug_render_path_list = []
-        if self.env is not None:
-            self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, 5, False)
-        else:
-            self.dead_lock_avoidance_agent = None
 
     def _explore(self, handle, new_position, new_direction, distance_map, depth=0):
         has_opp_agent = 0
         has_same_agent = 0
         has_target = 0
+        has_opp_target = 0
         visited = []
         min_dist = distance_map[handle, new_position[0], new_position[1], new_direction]
 
         # stop exploring (max_depth reached)
         if depth >= self.max_depth:
-            return has_opp_agent, has_same_agent, has_target, visited, min_dist
+            return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
 
         # max_explore_steps = 100 -> just to ensure that the exploration ends
         cnt = 0
@@ -92,7 +88,7 @@ class FastTreeObs(ObservationBuilder):
                 if self.env.agents[opp_a].direction != new_direction:
                     # opp agent found -> stop exploring. This would be a strong signal.
                     has_opp_agent = 1
-                    return has_opp_agent, has_same_agent, has_target, visited, min_dist
+                    return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
                 else:
                     # same agent found
                     # the agent can follow the agent, because this agent is still moving ahead and there shouldn't
@@ -101,7 +97,8 @@ class FastTreeObs(ObservationBuilder):
                     # target on this branch -> thus the agents should scan further whether there will be an opposite
                     # agent walking on same track
                     has_same_agent = 1
-                    # !NOT stop exploring! return has_opp_agent, has_same_agent, has_switch, visited,min_dist
+                    # same-direction agent found ahead -> stop exploring this branch
+                    return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
 
             # agents_on_switch == TRUE -> Current cell is a switch where the agent can decide (branch) in exploration
             # agent_near_to_switch == TRUE -> One cell before the switch, where the agent can decide
@@ -112,10 +109,14 @@ class FastTreeObs(ObservationBuilder):
             if agents_near_to_switch:
                 # The exploration was walking on a path where the agent can not decide
                 # Best option would be MOVE_FORWARD -> Skip exploring - just walking
-                return has_opp_agent, has_same_agent, has_target, visited, min_dist
+                return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
+
+            if new_position in self.agents_target:
+                has_opp_target = 1
 
             if self.env.agents[handle].target == new_position:
                 has_target = 1
+                return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
 
             possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
             if agents_on_switch:
@@ -130,30 +131,30 @@ class FastTreeObs(ObservationBuilder):
                     # --- OPEN RESEARCH QUESTION ---> is this good or shall we use full detailed information as
                     # we did in the TreeObservation (FLATLAND) ?
                     if possible_transitions[dir_loop] == 1:
-                        hoa, hsa, ht, v, m_dist = self._explore(handle,
-                                                                get_new_position(new_position, dir_loop),
-                                                                dir_loop,
-                                                                distance_map,
-                                                                depth + 1)
+                        hoa, hsa, ht, hot, v, m_dist = self._explore(handle,
+                                                                     get_new_position(new_position, dir_loop),
+                                                                     dir_loop,
+                                                                     distance_map,
+                                                                     depth + 1)
                         visited.append(v)
-                        has_opp_agent += max(hoa, has_opp_agent)
-                        has_same_agent += max(hsa, has_same_agent)
+                        has_opp_agent = max(hoa, has_opp_agent)
+                        has_same_agent = max(hsa, has_same_agent)
                         has_target = max(has_target, ht)
+                        has_opp_target = max(has_opp_target, hot)
                         min_dist = min(min_dist, m_dist)
-                return has_opp_agent, has_same_agent, has_target, visited, min_dist
+                return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
             else:
                 new_direction = fast_argmax(possible_transitions)
                 new_position = get_new_position(new_position, new_direction)
 
             min_dist = min(min_dist, distance_map[handle, new_position[0], new_position[1], new_direction])
 
-        return has_opp_agent, has_same_agent, has_target, visited, min_dist
+        return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
 
     def get_many(self, handles: Optional[List[int]] = None):
-        self.dead_lock_avoidance_agent.start_step(train=False)
         self.agent_positions = get_agent_positions(self.env)
+        self.agents_target = get_agent_targets(self.env)
         observations = super().get_many(handles)
-        self.dead_lock_avoidance_agent.end_step(train=False)
         return observations
 
     def get(self, handle: int = 0):
@@ -184,8 +185,6 @@ class FastTreeObs(ObservationBuilder):
         # observation[23] : If there is a switch on the path which agent can not use -> 1
         # observation[24] : If there is a switch on the path which agent can not use -> 1
         # observation[25] : If there is a switch on the path which agent can not use -> 1
-        # observation[26] : If there the dead-lock avoidance agent predicts a deadlock -> 1
-        # observation[27] : If there the agent can only walk forward or stop -> 1
 
         observation = np.zeros(self.observation_dim)
         visited = []
@@ -223,24 +222,18 @@ class FastTreeObs(ObservationBuilder):
                     if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
                         observation[dir_loop] = int(new_cell_dist < current_cell_dist)
 
-                    has_opp_agent, has_same_agent, has_target, v, min_dist = self._explore(handle,
-                                                                                           new_position,
-                                                                                           branch_direction,
-                                                                                           distance_map)
+                    has_opp_agent, has_same_agent, has_target, has_opp_target, v, min_dist = self._explore(handle,
+                                                                                                           new_position,
+                                                                                                           branch_direction,
+                                                                                                           distance_map)
                     visited.append(v)
 
                     if not (np.math.isinf(min_dist) and np.math.isinf(current_cell_dist)):
-                        observation[31 + dir_loop] = int(min_dist < current_cell_dist)
-                    observation[11 + dir_loop] = int(not np.math.isinf(new_cell_dist))
+                        observation[11 + dir_loop] = int(min_dist < current_cell_dist)
                     observation[15 + dir_loop] = has_opp_agent
                     observation[19 + dir_loop] = has_same_agent
                     observation[23 + dir_loop] = has_target
-                    observation[27 + dir_loop] = int(np.math.isinf(new_cell_dist))
-                    observation[36] = int(check_for_deadlock(handle,
-                                                             self.env,
-                                                             self.agent_positions,
-                                                             new_position,
-                                                             branch_direction))
+                    observation[27 + dir_loop] = has_opp_target
 
             agents_on_switch, \
             agents_near_to_switch, \
@@ -253,11 +246,6 @@ class FastTreeObs(ObservationBuilder):
             observation[9] = int(agents_near_to_switch)
             observation[10] = int(agents_near_to_switch_all)
 
-            action = self.dead_lock_avoidance_agent.act(handle, None, 0.0)
-            observation[35] = int(action == RailEnvActions.STOP_MOVING)
-
-            observation[40] = int(check_for_deadlock(handle, self.env, self.agent_positions))
-
         self.env.dev_obs_dict.update({handle: visited})
 
         observation[np.isinf(observation)] = -1
-- 
GitLab