From d6103087ab59c61e8e0eb449385594833ccf5e7a Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Tue, 3 Nov 2020 16:05:33 +0100
Subject: [PATCH] ...

---
 reinforcement_learning/dddqn_policy.py        |   6 +-
 .../multi_agent_training.py                   |   4 +-
 run.py                                        |   4 +-
 utils/dead_lock_avoidance_agent.py            | 192 ++++++++++++++++++
 utils/fast_tree_obs.py                        |   6 +-
 5 files changed, 202 insertions(+), 10 deletions(-)
 create mode 100644 utils/dead_lock_avoidance_agent.py

diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 32d7110..c1177b9 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -123,7 +123,11 @@ class DDDQNPolicy(Policy):
             self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
             self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
         else:
-            raise FileNotFoundError("Couldn't load policy from: '{}', '{}'".format(filename + ".local", filename + ".target"))
+            if os.path.exists(filename):
+                self.qnetwork_local.load_state_dict(torch.load(filename))
+                self.qnetwork_target.load_state_dict(torch.load(filename))
+            else:
+                raise FileNotFoundError("Couldn't load policy from: '{}', '{}'".format(filename + ".local", filename + ".target"))
 
     def save_replay_buffer(self, filename):
         memory = self.memory.memory
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 195b46a..9c78a72 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -322,7 +322,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
         # Print logs
         if episode_idx % checkpoint_interval == 0:
-            torch.save(policy.qnetwork_local, './checkpoints/' + training_id + '-' + str(episode_idx) + '.pth')
+            policy.save('./checkpoints/' + training_id + '-' + str(episode_idx) + '.pth')
 
             if save_replay_buffer:
                 policy.save_replay_buffer('./replay_buffers/' + training_id + '-' + str(episode_idx) + '.pkl')
@@ -475,7 +475,7 @@ if __name__ == "__main__":
     parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=0, type=int)
     parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
                         type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=50, type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
     parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
diff --git a/run.py b/run.py
index a2309a6..918d78c 100644
--- a/run.py
+++ b/run.py
@@ -28,7 +28,7 @@ from utils.observation_utils import normalize_observation
 VERBOSE = True
 
 # Checkpoint to use (remember to push it!)
-checkpoint = "checkpoints/201014015722-1500.pth"
+checkpoint = "checkpoints/201103150429-2500.pth"
 
 # Use last action cache
 USE_ACTION_CACHE = True
@@ -55,7 +55,7 @@ action_size = 5
 policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
 
 if os.path.isfile(checkpoint):
-    policy.qnetwork_local = torch.load(checkpoint)
+    policy.load(checkpoint)
 else:
     print("Checkpoint not found, using untrained policy! (path: {})".format(checkpoint))
 
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
new file mode 100644
index 0000000..c7c6d8a
--- /dev/null
+++ b/utils/dead_lock_avoidance_agent.py
@@ -0,0 +1,192 @@
+from typing import Optional, List
+
+import matplotlib.pyplot as plt
+import numpy as np
+from flatland.core.env_observation_builder import DummyObservationBuilder
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
+
+from reinforcement_learning.policy import Policy
+from utils.shortest_Distance_walker import ShortestDistanceWalker
+
+
+class DeadlockAvoidanceObservation(DummyObservationBuilder):
+    def __init__(self):
+        self.counter = 0
+
+    def get_many(self, handles: Optional[List[int]] = None) -> np.ndarray:
+        self.counter += 1
+        obs = np.ones((len(handles), 2))
+        for handle in handles:
+            obs[handle][0] = handle
+            obs[handle][1] = self.counter
+        return obs
+
+
+class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker):
+    def __init__(self, env: RailEnv, agent_positions, switches):
+        super().__init__(env)
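+        # Maps are initialised to -1 ("not on the agent's shortest path");
+        # callback() sets visited cells to 1.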
+        self.shortest_distance_agent_map = np.zeros((self.env.get_num_agents(),
+                                                     self.env.height,
+                                                     self.env.width),
+                                                    dtype=int) - 1
+
+        self.full_shortest_distance_agent_map = np.zeros((self.env.get_num_agents(),
+                                                          self.env.height,
+                                                          self.env.width),
+                                                         dtype=int) - 1
+
+        self.agent_positions = agent_positions
+
+        self.opp_agent_map = {}
+        self.same_agent_map = {}
+        self.switches = switches
+
+    def getData(self):
+        return self.shortest_distance_agent_map, self.full_shortest_distance_agent_map
+
+    def callback(self, handle, agent, position, direction, action, possible_transitions):
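+        # Called for every cell visited while walking this agent's shortest path.
+        # Agents met that travel in a different direction are recorded as potential
+        # deadlock partners; same-direction agents are only tracked before that.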
+        opp_a = self.agent_positions[position]
+        if opp_a != -1 and opp_a != handle:
+            if self.env.agents[opp_a].direction != direction:
+                d = self.opp_agent_map.get(handle, [])
+                if opp_a not in d:
+                    d.append(opp_a)
+                self.opp_agent_map.update({handle: d})
+            else:
+                if len(self.opp_agent_map.get(handle, [])) == 0:
+                    d = self.same_agent_map.get(handle, [])
+                    if opp_a not in d:
+                        d.append(opp_a)
+                    self.same_agent_map.update({handle: d})
+
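+        # shortest_distance_agent_map marks only non-switch cells reached while no
+        # opposing agent has been seen; full_shortest_distance_agent_map marks the whole path.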
+        if len(self.opp_agent_map.get(handle, [])) == 0:
+            if self.switches.get(position, None) is None:
+                self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1
+        self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1
+
+
+class DeadLockAvoidanceAgent(Policy):
+    def __init__(self, env: RailEnv, show_debug_plot=False):
+        self.env = env
+        self.memory = None
+        self.loss = 0
+        self.agent_can_move = {}
+        self.switches = {}
+        self.show_debug_plot = show_debug_plot
+
+    def step(self, state, action, reward, next_state, done):
+        pass
+
+    def act(self, state, eps=0.):
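+        # state[0] carries the agent handle (see DeadlockAvoidanceObservation);
+        # stop unless start_step() found a safe next move for this agent.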
+        # agent = self.env.agents[state[0]]
+        check = self.agent_can_move.get(state[0], None)
+        if check is None:
+            return RailEnvActions.STOP_MOVING
+        return check[3]
+
+    def reset(self):
+        self.agent_positions = None
+        self.shortest_distance_walker = None
+        self.switches = {}
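+        # Collect all switch cells, i.e. cells with more than one outgoing
+        # transition for at least one incoming direction.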
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                pos = (h, w)
+                for dir in range(4):
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    num_transitions = fast_count_nonzero(possible_transitions)
+                    if num_transitions > 1:
+                        if pos not in self.switches.keys():
+                            self.switches.update({pos: [dir]})
+                        else:
+                            self.switches[pos].append(dir)
+
+    def start_step(self):
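+        # Recompute the agent position map, shortest-path maps and per-agent safe moves.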
+        self.build_agent_position_map()
+        self.shortest_distance_mapper()
+        self.extract_agent_can_move()
+
+    def end_step(self):
+        pass
+
+    def get_actions(self):
+        pass
+
+    def build_agent_position_map(self):
+        # build map with agent positions (only active agents)
+        self.agent_positions = np.zeros((self.env.height, self.env.width), dtype=int) - 1
+        for handle in range(self.env.get_num_agents()):
+            agent = self.env.agents[handle]
+            if agent.status == RailAgentStatus.ACTIVE:
+                if agent.position is not None:
+                    self.agent_positions[agent.position] = handle
+
+    def shortest_distance_mapper(self):
+        self.shortest_distance_walker = DeadlockAvoidanceShortestDistanceWalker(self.env,
+                                                                                self.agent_positions,
+                                                                                self.switches)
+        for handle in range(self.env.get_num_agents()):
+            agent = self.env.agents[handle]
+            if agent.status <= RailAgentStatus.ACTIVE:
+                self.shortest_distance_walker.walk_to_target(handle)
+
+    def extract_agent_can_move(self):
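+        # Cache next position, direction and action for every agent whose next step
+        # is considered deadlock-free; agents without an entry are stopped in act().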
+        self.agent_can_move = {}
+        shortest_distance_agent_map, full_shortest_distance_agent_map = self.shortest_distance_walker.getData()
+        for handle in range(self.env.get_num_agents()):
+            agent = self.env.agents[handle]
+            if agent.status < RailAgentStatus.DONE:
+                next_step_ok = self.check_agent_can_move(shortest_distance_agent_map[handle],
+                                                         self.shortest_distance_walker.same_agent_map.get(handle, []),
+                                                         self.shortest_distance_walker.opp_agent_map.get(handle, []),
+                                                         full_shortest_distance_agent_map)
+                if next_step_ok:
+                    next_position, next_direction, action, _ = self.shortest_distance_walker.walk_one_step(handle)
+                    self.agent_can_move.update({handle: [next_position[0], next_position[1], next_direction, action]})
+
+        if self.show_debug_plot:
+            a = int(np.floor(np.sqrt(self.env.get_num_agents())))
+            b = int(np.ceil(self.env.get_num_agents() / a))
+            for handle in range(self.env.get_num_agents()):
+                plt.subplot(a, b, handle + 1)
+                plt.imshow(full_shortest_distance_agent_map[handle] + shortest_distance_agent_map[handle])
+            plt.show(block=False)
+            plt.pause(0.01)
+
+    def check_agent_can_move(self,
+                             my_shortest_walking_path,
+                             same_agents,
+                             opp_agents,
+                             full_shortest_distance_agent_map):
+        agent_positions_map = (self.agent_positions > -1).astype(int)
+        delta = my_shortest_walking_path
+        next_step_ok = True
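+        # Heuristic: count the cells on this agent's restricted path that are neither
+        # on the opposing agent's full path nor currently occupied; if fewer than
+        # 3 + len(opp_agents) such cells remain, the agent has to stop.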
+        for opp_a in opp_agents:
+            opp = full_shortest_distance_agent_map[opp_a]
+            delta = ((my_shortest_walking_path - opp - agent_positions_map) > 0).astype(int)
+            if np.sum(delta) < (3 + len(opp_agents)):
+                next_step_ok = False
+        return next_step_ok
+
+    def save(self, filename):
+        pass
+
+    def load(self, filename):
+        pass
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index 12c91ca..7e4c934 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -294,8 +294,4 @@ class FastTreeObs(ObservationBuilder):
 
         self.env.dev_obs_dict.update({handle: visited})
 
-        return observation
-
-    @staticmethod
-    def agent_can_choose(observation):
-        return observation[7] == 1 or observation[8] == 1
+        return observation
\ No newline at end of file
-- 
GitLab