From 3f2468c23685eabffc2a478758a87c9f191b2ed9 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Wed, 28 Oct 2020 10:48:33 +0100
Subject: [PATCH] DeadLockAvoidance used for extra obs (current
 position/direction check and as well for branching checks (one step ahead)

---
 .../multi_agent_training.py                   | 22 +++++++++---
 utils/dead_lock_avoidance_agent.py            |  2 +-
 utils/extra.py                                | 36 +++++++++----------
 3 files changed, 34 insertions(+), 26 deletions(-)

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index dc4fc33..542f587 100644
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -9,6 +9,7 @@ from pprint import pprint
 
 import numpy as np
 import psutil
+from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -196,8 +197,6 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         policy = PPOAgent(state_size, action_size, n_agents, train_env)
     if False:
         policy = MultiPolicy(state_size, action_size, n_agents, train_env)
-    if False:
-        policy = DeadLockAvoidanceAgent(train_env, state_size, action_size)
 
     # Load existing policy
     if train_params.load_policy is not None:
@@ -244,6 +243,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
             training_id
         ))
 
+    rl_policy = policy
     for episode_idx in range(n_episodes + 1):
         step_timer = Timer()
         reset_timer = Timer()
@@ -254,6 +254,18 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         # Reset environment
         reset_timer.start()
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
+
+        # train different number of agents : 1,2,3,... n_agents
+        for handle in range(train_env.get_num_agents()):
+            if (episode_idx % n_agents) < handle:
+                train_env.agents[handle].status = RailAgentStatus.DONE_REMOVED
+
+        # start with simple deadlock avoidance agent policy (imitation learning?)
+        if episode_idx < 500:
+            policy = DeadLockAvoidanceAgent(train_env, state_size, action_size)
+        else:
+            policy = rl_policy
+
         policy.reset()
         reset_timer.end()
 
@@ -512,7 +524,7 @@ if __name__ == "__main__":
     parser.add_argument("--eps_start", help="max exploration", default=0.5, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
     parser.add_argument("--eps_decay", help="exploration decay", default=0.9985, type=float)
-    parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e5), type=int)
+    parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e6), type=int)
     parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
     parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
     parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False,
@@ -528,8 +540,8 @@ if __name__ == "__main__":
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
     parser.add_argument("--use_extra_observation", help="extra observation", default=True, type=bool)
     parser.add_argument("--close_following", help="enable close following feature", default=True, type=bool)
-    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("--max_depth", help="max depth", default=2, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2,
                         type=int)
     parser.add_argument("--render", help="render 1 episode in 100", default=False, type=bool)
 
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index 62e76ea..39090a1 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -148,6 +148,6 @@ class DeadLockAvoidanceAgent(Policy):
         for opp_a in opp_agents:
             opp = full_shortest_distance_agent_map[opp_a]
             delta = ((delta - opp - agent_positions_map) > 0).astype(int)
-            if (np.sum(delta) < 1 + len(opp_agents)):
+            if (np.sum(delta) < 2 + len(opp_agents)):
                 next_step_ok = False
         return next_step_ok
diff --git a/utils/extra.py b/utils/extra.py
index 340cb40..89ed0bb 100644
--- a/utils/extra.py
+++ b/utils/extra.py
@@ -183,14 +183,21 @@ class Extra(ObservationBuilder):
         self.build_data()
         return
 
-    def fast_argmax(self, array):
-        if array[0] == 1:
-            return 0
-        if array[1] == 1:
-            return 1
-        if array[2] == 1:
-            return 2
-        return 3
+
+    def _check_dead_lock_at_branching_position(self, handle, new_position, branch_direction):
+        _, full_shortest_distance_agent_map = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
+        opp_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.opp_agent_map.get(handle, [])
+        local_walker = DeadlockAvoidanceShortestDistanceWalker(
+            self.env,
+            self.dead_lock_avoidance_agent.shortest_distance_walker.agent_positions,
+            self.dead_lock_avoidance_agent.shortest_distance_walker.switches)
+        local_walker.walk_to_target(handle, new_position, branch_direction)
+        shortest_distance_agent_map, _ = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
+        my_shortest_path_to_check = shortest_distance_agent_map[handle]
+        next_step_ok = self.dead_lock_avoidance_agent.check_agent_can_move(my_shortest_path_to_check,
+                                                                           opp_agents,
+                                                                           full_shortest_distance_agent_map)
+        return next_step_ok
 
     def _explore(self, handle, new_position, new_direction, depth=0):
 
@@ -332,18 +339,7 @@ class Extra(ObservationBuilder):
                     observation[18 + dir_loop] = has_same_agent
                     observation[22 + dir_loop] = has_switch
 
-                    _, full_shortest_distance_agent_map = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
-                    opp_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.opp_agent_map.get(handle, [])
-                    local_walker = DeadlockAvoidanceShortestDistanceWalker(
-                        self.env,
-                        self.dead_lock_avoidance_agent.shortest_distance_walker.agent_positions,
-                        self.dead_lock_avoidance_agent.shortest_distance_walker.switches)
-                    local_walker.walk_to_target(handle, new_position, branch_direction)
-                    shortest_distance_agent_map, _ = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
-                    my_shortest_path_to_check = shortest_distance_agent_map[handle]
-                    next_step_ok = self.dead_lock_avoidance_agent.check_agent_can_move(my_shortest_path_to_check,
-                                                                                       opp_agents,
-                                                                                       full_shortest_distance_agent_map)
+                    next_step_ok = self._check_dead_lock_at_branching_position(handle, new_position, branch_direction)
                     if next_step_ok:
                         observation[26 + dir_loop] = 1
 
-- 
GitLab