From 02170adc31ca186b2a1a36ee060027f51e799c1b Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Fri, 23 Oct 2020 11:02:53 +0200
Subject: [PATCH] Q&D

---
 .../multi_agent_training.py                   |  44 ++-
 run_fast_methods.py                           |  26 ++
 utils/dead_lock_avoidance_agent.py            |   4 +-
 utils/extra.py                                | 251 +++++-------------
 4 files changed, 118 insertions(+), 207 deletions(-)
 create mode 100644 run_fast_methods.py

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index c47b484..c62a749 100644
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -191,7 +191,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
     if False:
         policy = ExtraPolicy(state_size, action_size)
     if False:
-        policy = PPOAgent(state_size, action_size, n_agents)
+        policy = PPOAgent(state_size, action_size, n_agents, train_env)
     if False:
         policy = MultiPolicy(state_size, action_size, n_agents, train_env)
     if False:
@@ -253,7 +253,6 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         reset_timer.start()
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
         policy.reset()
-        deadLockAvoidanceAgent = DeadLockAvoidanceAgent(train_env, state_size, action_size)
         reset_timer.end()
 
         if train_params.render:
@@ -274,24 +273,19 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         for step in range(max_steps - 1):
             inference_timer.start()
             policy.start_step()
-            deadLockAvoidanceAgent.start_step()
             for agent in train_env.get_agent_handles():
-                action = deadLockAvoidanceAgent.act(agent, None, 0.0)
                 update_values[agent] = False
-                if action != RailEnvActions.STOP_MOVING:
-                    if info['action_required'][agent]:
-                        update_values[agent] = True
-                        action = policy.act(agent, agent_obs[agent], eps=eps_start)
-                        action_count[action] += 1
-                        actions_taken.append(action)
-                    else:
-                        # An action is not required if the train hasn't joined the railway network,
-                        # if it already reached its target, or if is currently malfunctioning.
-                        action = 0
-
+                if info['action_required'][agent]:
+                    update_values[agent] = True
+                    action = policy.act(agent, agent_obs[agent], eps=eps_start)
+                    action_count[action] += 1
+                    actions_taken.append(action)
+                else:
+                    # An action is not required if the train hasn't joined the railway network,
+                    # if it already reached its target, or if it is currently malfunctioning.
+                    action = 0
                 action_dict.update({agent: action})
             policy.end_step()
-            deadLockAvoidanceAgent.end_step()
 
             inference_timer.end()
 
@@ -464,26 +458,22 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
         score = 0.0
 
         obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
-        deadLockAvoidanceAgent = DeadLockAvoidanceAgent(env, None, None)
 
         final_step = 0
 
         for step in range(max_steps - 1):
-            deadLockAvoidanceAgent.start_step()
             for agent in env.get_agent_handles():
                 if tree_observation.check_is_observation_valid(agent_obs[agent]):
                     agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], tree_depth=tree_depth,
                                                                                    observation_radius=observation_radius)
 
-                action = deadLockAvoidanceAgent.act(agent, None, 0)
-                if action != RailEnvActions.STOP_MOVING:
-                    if info['action_required'][agent]:
-                        if tree_observation.check_is_observation_valid(agent_obs[agent]):
-                            action = policy.act(agent, agent_obs[agent], eps=0.0)
+                action = RailEnvActions.DO_NOTHING
+                if info['action_required'][agent]:
+                    if tree_observation.check_is_observation_valid(agent_obs[agent]):
+                        action = policy.act(agent, agent_obs[agent], eps=0.0)
                 action_dict.update({agent: action})
 
             obs, all_rewards, done, info = env.step(action_dict)
-            deadLockAvoidanceAgent.end_step()
 
             for agent in env.get_agent_handles():
                 score += all_rewards[agent]
@@ -517,9 +507,9 @@ if __name__ == "__main__":
                         type=int)
     parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=200, type=int)
-    parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
+    parser.add_argument("--eps_start", help="max exploration", default=0.5, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.0001, type=float)
-    parser.add_argument("--eps_decay", help="exploration decay", default=0.999, type=float)
+    parser.add_argument("--eps_decay", help="exploration decay", default=0.9997, type=float)
     parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e5), type=int)
     parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
     parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
@@ -535,8 +525,8 @@ if __name__ == "__main__":
     parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
     parser.add_argument("--use_extra_observation", help="extra observation", default=True, type=bool)
-    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
     parser.add_argument("--close_following", help="enable close following feature", default=True, type=bool)
+    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
     parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int)
     parser.add_argument("--render", help="render 1 episode in 100", default=False, type=bool)
 
diff --git a/run_fast_methods.py b/run_fast_methods.py
new file mode 100644
index 0000000..dbb5ea1
--- /dev/null
+++ b/run_fast_methods.py
@@ -0,0 +1,26 @@
+from time import time
+
+import numpy as np
+from flatland.envs.rail_env import fast_isclose
+
+
+def print_timing(label, start_time, end_time):
+    print("{:>10.4f}ms".format(1000 * (end_time - start_time)) + "\t" + label)
+
+
+def check_isclose(nbr=100000):
+    s = time()
+    for x in range(nbr):
+        fast_isclose(x, 0.0, rtol=1e-03)
+    e = time()
+    print_timing("fast_isclose", start_time=s, end_time=e)
+
+    s = time()
+    for x in range(nbr):
+        np.isclose(x, 0.0, rtol=1e-03)
+    e = time()
+    print_timing("np.isclose", start_time=s, end_time=e)
+
+
+if __name__ == "__main__":
+    check_isclose()
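
Note on run_fast_methods.py: the new script micro-benchmarks flatland's fast_isclose against np.isclose by calling each 100000 times and printing the elapsed wall-clock time per loop. The same comparison can also be expressed with the standard-library timeit module; the sketch below is illustrative only (the loop count and rtol value simply mirror the script above) and is not part of the patch.

    import timeit

    import numpy as np
    from flatland.envs.rail_env import fast_isclose

    n = 100000
    # Time n calls of each implementation; the lambdas keep the call overhead comparable.
    t_fast = timeit.timeit(lambda: fast_isclose(1.0, 0.0, rtol=1e-03), number=n)
    t_np = timeit.timeit(lambda: np.isclose(1.0, 0.0, rtol=1e-03), number=n)
    print("{:>10.4f}ms\tfast_isclose".format(1000 * t_fast))
    print("{:>10.4f}ms\tnp.isclose".format(1000 * t_np))
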
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index bb9dc3d..71140c3 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -117,6 +117,7 @@ class DeadLockAvoidanceAgent(Policy):
                 my_walker.walk_to_target(handle)
         shortest_distance_agent_map, full_shortest_distance_agent_map = my_walker.getData()
 
+        delta_data = np.copy(full_shortest_distance_agent_map)
         self.agent_can_move = {}
         agent_positions_map = (agent_positions > -1).astype(int)
         for handle in range(self.env.get_num_agents()):
@@ -129,6 +130,7 @@ class DeadLockAvoidanceAgent(Policy):
             for opp_a in opp_agents:
                 opp = full_shortest_distance_agent_map[opp_a]
                 delta = ((delta - opp - agent_positions_map) > 0).astype(int)
+                delta_data[handle] += np.clip(delta, 0, 1)
                 if (np.sum(delta) < 3):
                     next_step_ok = False
 
@@ -140,7 +142,7 @@ class DeadLockAvoidanceAgent(Policy):
             b = np.ceil(self.env.get_num_agents() / a)
             for handle in range(self.env.get_num_agents()):
                 plt.subplot(a, b, handle + 1)
-                plt.imshow(shortest_distance_agent_map[handle])
+                plt.imshow(delta_data[handle])
             # plt.colorbar()
             plt.show(block=False)
             plt.pause(0.01)
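
Note on the dead_lock_avoidance_agent.py change: it only affects the optional matplotlib debug view. delta_data starts as a copy of each agent's full shortest-distance map and, for every opposing agent, accumulates the cells of the agent's path that remain free after subtracting the opponent's path and the currently occupied cells, so the plot now shows this free-cell overlay instead of the raw shortest-distance map. A minimal, self-contained sketch of that set arithmetic on toy 1x5 arrays (the array contents are made up for illustration, not taken from the patch):

    import numpy as np

    my_path = np.array([[1, 1, 1, 1, 0]])              # cells on this agent's shortest path
    opp_path = np.array([[0, 0, 1, 1, 1]])             # full path of one opposing agent
    agent_positions_map = np.array([[0, 0, 0, 0, 1]])  # currently occupied cells

    # Same arithmetic as in the agent: keep only cells that stay strictly positive
    # after removing the opponent's path and the occupied cells.
    delta = ((my_path - opp_path - agent_positions_map) > 0).astype(int)

    next_step_ok = np.sum(delta) >= 3  # the agent requires at least 3 free cells ahead
    print(delta, next_step_ok)
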
diff --git a/utils/extra.py b/utils/extra.py
index 83263c7..4b14b84 100644
--- a/utils/extra.py
+++ b/utils/extra.py
@@ -1,11 +1,11 @@
-# import matplotlib.pyplot as plt
 import numpy as np
+
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.rail_env import RailEnvActions, RailAgentStatus, RailEnv
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.rail_env import RailEnvActions, fast_argmax, fast_count_nonzero
 
 from reinforcement_learning.policy import Policy
-from utils.shortest_Distance_walker import ShortestDistanceWalker
 
 
 class ExtraPolicy(Policy):
@@ -52,53 +52,11 @@ class ExtraPolicy(Policy):
         pass
 
 
-def fast_argmax(possible_transitions: (int, int, int, int)) -> bool:
-    if possible_transitions[0] == 1:
-        return 0
-    if possible_transitions[1] == 1:
-        return 1
-    if possible_transitions[2] == 1:
-        return 2
-    return 3
-
-
-def fast_count_nonzero(possible_transitions: (int, int, int, int)):
-    return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3]
-
-
 class Extra(ObservationBuilder):
 
     def __init__(self, max_depth):
         self.max_depth = max_depth
-        self.observation_dim = 62
-
-    def shortest_distance_mapper(self):
-
-        class MyWalker(ShortestDistanceWalker):
-            def __init__(self, env: RailEnv):
-                super().__init__(env)
-                self.shortest_distance_agent_counter = np.zeros((self.env.height, self.env.width), dtype=int)
-                self.shortest_distance_agent_direction_counter = np.zeros((self.env.height, self.env.width, 4),
-                                                                          dtype=int)
-
-            def getData(self):
-                return self.shortest_distance_agent_counter, self.shortest_distance_agent_direction_counter
-
-            def callback(self, handle, agent, position, direction, action, possible_transitions):
-                self.shortest_distance_agent_counter[position] += 1
-                self.shortest_distance_agent_direction_counter[(position[0], position[1], direction)] += 1
-
-        my_walker = MyWalker(self.env)
-        for handle in range(self.env.get_num_agents()):
-            agent = self.env.agents[handle]
-            if agent.status <= RailAgentStatus.ACTIVE:
-                my_walker.walk_to_target(handle)
-
-        self.shortest_distance_agent_counter, self.shortest_distance_agent_direction_counter = my_walker.getData()
-
-        # plt.imshow(self.shortest_distance_agent_counter)
-        # plt.colorbar()
-        # plt.show()
+        self.observation_dim = 26
 
     def build_data(self):
         if self.env is not None:
@@ -109,13 +67,6 @@ class Extra(ObservationBuilder):
         self.debug_render_path_list = []
         if self.env is not None:
             self.find_all_cell_where_agent_can_choose()
-            self.agent_positions = np.zeros((self.env.height, self.env.width), dtype=int) - 1
-            self.history_direction = np.zeros((self.env.height, self.env.width), dtype=int) - 1
-            self.history_same_direction_cnt = np.zeros((self.env.height, self.env.width), dtype=int)
-            self.history_time = np.zeros((self.env.height, self.env.width), dtype=int) - 1
-
-        self.shortest_distance_agent_counter = None
-        self.shortest_distance_agent_direction_counter = None
 
     def find_all_cell_where_agent_can_choose(self):
 
@@ -238,39 +189,32 @@ class Extra(ObservationBuilder):
             return 2
         return 3
 
-    def _explore(self, handle, new_position, new_direction, distance_map, depth):
+    def _explore(self, handle, new_position, new_direction, depth=0):
 
-        may_has_opp_agent = 0
-        has_opp_agent = -1
-        has_other_target = 0
-        has_target = 0
+        has_opp_agent = 0
+        has_same_agent = 0
+        has_switch = 0
         visited = []
 
-        new_cell_dist = np.inf
-
         # stop exploring (max_depth reached)
-        if depth > self.max_depth:
-            return has_opp_agent, may_has_opp_agent, has_other_target, has_target, visited, new_cell_dist
+        if depth >= self.max_depth:
+            return has_opp_agent, has_same_agent, has_switch, visited
 
         # max_explore_steps = 100
         cnt = 0
         while cnt < 100:
             cnt += 1
-            has_other_target = int(new_position in self.agent_targets)
-            new_cell_dist = min(new_cell_dist, distance_map[handle,
-                                                            new_position[0], new_position[1],
-                                                            new_direction])
 
             visited.append(new_position)
-            has_target = int(self.env.agents[handle].target == new_position)
-            opp_a = self.agent_positions[new_position]
+            opp_a = self.env.agent_positions[new_position]
             if opp_a != -1 and opp_a != handle:
-                possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
-                if possible_transitions[self.env.agents[opp_a].direction] < 1:
+                if self.env.agents[opp_a].direction != new_direction:
                     # opp agent found
-                    has_opp_agent = opp_a
-                    may_has_opp_agent = 1
-                    return has_opp_agent, may_has_opp_agent, has_other_target, has_target, visited, new_cell_dist
+                    has_opp_agent = 1
+                    return has_opp_agent, has_same_agent, has_switch, visited
+                else:
+                    has_same_agent = 1
+                    return has_opp_agent, has_same_agent, has_switch, visited
 
             # convert one-hot encoding to 0,1,2,3
             agents_on_switch, \
@@ -278,40 +222,32 @@ class Extra(ObservationBuilder):
             agents_near_to_switch_all, \
             agents_on_switch_all = \
                 self.check_agent_descision(new_position, new_direction)
-
             if agents_near_to_switch:
-                return has_opp_agent, may_has_opp_agent, has_other_target, has_target, visited, new_cell_dist
+                return has_opp_agent, has_same_agent, has_switch, visited
 
             possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
-            if fast_count_nonzero(possible_transitions) > 1:
-                may_has_opp_agent_loop = 1
+            if agents_on_switch:
+                f = 0
                 for dir_loop in range(4):
                     if possible_transitions[dir_loop] == 1:
-                        hoa, mhoa, hot, ht, v, min_cell_dist = self._explore(handle,
-                                                                             get_new_position(new_position,
-                                                                                              dir_loop),
-                                                                             dir_loop,
-                                                                             distance_map,
-                                                                             depth + 1)
-
-                        has_opp_agent = max(has_opp_agent, hoa)
-                        may_has_opp_agent_loop = min(may_has_opp_agent_loop, mhoa)
-                        has_other_target = max(has_other_target, hot)
-                        has_target = max(has_target, ht)
+                        f += 1
+                        hoa, hsa, hs, v = self._explore(handle,
+                                                        get_new_position(new_position, dir_loop),
+                                                        dir_loop,
+                                                        depth + 1)
                         visited.append(v)
-                        new_cell_dist = min(min_cell_dist, new_cell_dist)
-                return has_opp_agent, may_has_opp_agent_loop, has_other_target, has_target, visited, new_cell_dist
+                        has_opp_agent += hoa
+                        has_same_agent += hsa
+                        has_switch += hs
+                f = max(f, 1.0)
+                return has_opp_agent / f, has_same_agent / f, has_switch / f, visited
             else:
                 new_direction = fast_argmax(possible_transitions)
                 new_position = get_new_position(new_position, new_direction)
 
-        return has_opp_agent, may_has_opp_agent, has_other_target, has_target, visited, new_cell_dist
+        return has_opp_agent, has_same_agent, has_switch, visited
 
     def get(self, handle):
-
-        if (handle == 0):
-            self.updateSharedData()
-
         # all values are [0,1]
         # observation[0]  : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path
         # observation[1]  : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path
@@ -319,10 +255,26 @@ class Extra(ObservationBuilder):
         # observation[3]  : 1 path towards target (direction 3) / otherwise 0 -> path is longer or there is no path
         # observation[4]  : int(agent.status == RailAgentStatus.READY_TO_DEPART)
         # observation[5]  : int(agent.status == RailAgentStatus.ACTIVE)
-        # observation[6] : If there is a path with step (direction 0) and there is a agent with opposite direction -> 1
-        # observation[7] : If there is a path with step (direction 1) and there is a agent with opposite direction -> 1
-        # observation[8] : If there is a path with step (direction 2) and there is a agent with opposite direction -> 1
-        # observation[9] : If there is a path with step (direction 3) and there is a agent with opposite direction -> 1
+        # observation[6]  : int(agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED)
+        # observation[7]  : current agent is located at a switch, where it can take a routing decision
+        # observation[8]  : current agent is located at a cell, where it has to take a stop-or-go decision
+        # observation[9]  : current agent is located one step before/after a switch
+        # observation[10] : 1 if there is a path (track/branch) otherwise 0 (direction 0)
+        # observation[11] : 1 if there is a path (track/branch) otherwise 0 (direction 1)
+        # observation[12] : 1 if there is a path (track/branch) otherwise 0 (direction 2)
+        # observation[13] : 1 if there is a path (track/branch) otherwise 0 (direction 3)
+        # observation[14] : If there is a path with step (direction 0) and there is an agent with opposite direction -> 1
+        # observation[15] : If there is a path with step (direction 1) and there is an agent with opposite direction -> 1
+        # observation[16] : If there is a path with step (direction 2) and there is an agent with opposite direction -> 1
+        # observation[17] : If there is a path with step (direction 3) and there is an agent with opposite direction -> 1
+        # observation[18] : If there is a path with step (direction 0) and there is an agent with same direction -> 1
+        # observation[19] : If there is a path with step (direction 1) and there is an agent with same direction -> 1
+        # observation[20] : If there is a path with step (direction 2) and there is an agent with same direction -> 1
+        # observation[21] : If there is a path with step (direction 3) and there is an agent with same direction -> 1
+        # observation[22] : If there is a switch on the path (direction 0) which the agent cannot use -> 1
+        # observation[23] : If there is a switch on the path (direction 1) which the agent cannot use -> 1
+        # observation[24] : If there is a switch on the path (direction 2) which the agent cannot use -> 1
+        # observation[25] : If there is a switch on the path (direction 3) which the agent cannot use -> 1
 
         observation = np.zeros(self.observation_dim)
         visited = []
@@ -331,11 +283,12 @@ class Extra(ObservationBuilder):
         agent_done = False
         if agent.status == RailAgentStatus.READY_TO_DEPART:
             agent_virtual_position = agent.initial_position
-            observation[0] = 1
+            observation[4] = 1
         elif agent.status == RailAgentStatus.ACTIVE:
             agent_virtual_position = agent.position
-            observation[1] = 1
+            observation[5] = 1
         else:
+            observation[6] = 1
             agent_virtual_position = (-1, -1)
             agent_done = True
 
@@ -356,90 +309,30 @@ class Extra(ObservationBuilder):
                     new_cell_dist = distance_map[handle,
                                                  new_position[0], new_position[1],
                                                  branch_direction]
-
-                    has_opp_agent, \
-                    may_has_opp_agent, \
-                    has_other_target, \
-                    has_target, \
-                    v, \
-                    min_cell_dist = self._explore(handle,
-                                                  new_position,
-                                                  branch_direction,
-                                                  distance_map,
-                                                  0)
                     if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
-                        observation[2 + dir_loop] = int(new_cell_dist < current_cell_dist)
-
-                    new_cell_dist = min(min_cell_dist, new_cell_dist)
-                    if not (np.math.isinf(new_cell_dist) and not np.math.isinf(current_cell_dist)):
-                        observation[6 + dir_loop] = int(new_cell_dist < current_cell_dist)
+                        observation[dir_loop] = int(new_cell_dist < current_cell_dist)
 
+                    has_opp_agent, has_same_agent, has_switch, v = self._explore(handle, new_position, branch_direction)
                     visited.append(v)
 
-                    observation[10 + dir_loop] = int(has_opp_agent > -1)
-                    observation[14 + dir_loop] = may_has_opp_agent
-                    observation[18 + dir_loop] = has_other_target
-                    observation[22 + dir_loop] = has_target
-                    observation[26 + dir_loop] = self.getHistorySameDirection(new_position, branch_direction)
-                    observation[30 + dir_loop] = self.getHistoryOppositeDirection(new_position, branch_direction)
-                    observation[34 + dir_loop] = self.getTemporalDistance(new_position)
-                    observation[38 + dir_loop] = self.getFlowDensity(new_position)
-                    observation[42 + dir_loop] = self.getDensitySameDirection(new_position, branch_direction)
-                    observation[44 + dir_loop] = self.getDensity(new_position)
-                    observation[48 + dir_loop] = int(not np.math.isinf(new_cell_dist))
-                    observation[52 + dir_loop] = 1
-                    observation[54 + dir_loop] = int(has_opp_agent > handle)
+                    observation[10 + dir_loop] = 1
+                    observation[14 + dir_loop] = has_opp_agent
+                    observation[18 + dir_loop] = has_same_agent
+                    observation[22 + dir_loop] = has_switch
+
+        agents_on_switch, \
+        agents_near_to_switch, \
+        agents_near_to_switch_all, \
+        agents_on_switch_all = \
+            self.check_agent_descision(agent_virtual_position, agent.direction)
+        observation[7] = int(agents_on_switch)
+        observation[8] = int(agents_near_to_switch)
+        observation[9] = int(agents_near_to_switch_all)
 
         self.env.dev_obs_dict.update({handle: visited})
 
         return observation
 
-    def getDensitySameDirection(self, position, direction):
-        val = self.shortest_distance_agent_direction_counter[(position[0], position[1], direction)]
-        return val / self.env.get_num_agents()
-
-    def getDensity(self, position):
-        val = self.shortest_distance_agent_counter[position]
-        return val / self.env.get_num_agents()
-
-    def getHistorySameDirection(self, position, direction):
-        val = self.history_direction[position]
-        if val == -1:
-            return -1
-        if val == direction:
-            return 1
-        return 0
-
-    def getHistoryOppositeDirection(self, position, direction):
-        val = self.getHistorySameDirection(position, direction)
-        if val == -1:
-            return -1
-        return 1 - val
-
-    def getTemporalDistance(self, position):
-        if self.history_time[position] == -1:
-            return -1
-        val = self.env._elapsed_steps - self.history_time[position]
-        if val < 1:
-            return 0
-        return 1 + np.log(1 + val)
-
-    def getFlowDensity(self, position):
-        val = self.env._elapsed_steps - self.history_same_direction_cnt[position]
-        return 1 + np.log(1 + val)
-
-    def updateSharedData(self):
-        self.shortest_distance_mapper()
-        self.agent_positions = np.zeros((self.env.height, self.env.width), dtype=int) - 1
-        self.agent_targets = []
-        for a in np.arange(self.env.get_num_agents()):
-            if self.env.agents[a].status == RailAgentStatus.ACTIVE:
-                self.agent_targets.append(self.env.agents[a].target)
-                if self.env.agents[a].position is not None:
-                    self.agent_positions[self.env.agents[a].position] = a
-                    if self.history_direction[self.env.agents[a].position] == self.env.agents[a].direction:
-                        self.history_same_direction_cnt[self.env.agents[a].position] += 1
-                    else:
-                        self.history_same_direction_cnt[self.env.agents[a].position] = 0
-                    self.history_direction[self.env.agents[a].position] = self.env.agents[a].direction
-                    self.history_time[self.env.agents[a].position] = self.env._elapsed_steps
+    @staticmethod
+    def agent_can_choose(observation):
+        return observation[7] == 1 or observation[8] == 1
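
Note on the rewritten Extra observation builder: the observation is now a flat vector of 26 values, laid out as documented in the comment block inside Extra.get(), and _explore() walks each branch direction to report whether an opposing agent, a same-direction agent, or an unusable switch lies ahead. For readability, a consumer could slice the vector into named groups; the helper below is hypothetical (not part of the patch) and its index ranges simply restate that comment block.

    import numpy as np

    def split_extra_observation(obs):
        """Slice the 26-dim Extra observation into the groups documented in Extra.get()."""
        assert len(obs) == 26
        return {
            "shorter_path_per_direction": obs[0:4],    # 1 if the branch shortens the distance to target
            "agent_status": obs[4:7],                  # ready_to_depart, active, done
            "decision_cells": obs[7:10],               # on a switch / stop-or-go cell / near a switch
            "branch_exists": obs[10:14],               # 1 if there is a track in that direction
            "opposite_agent_on_branch": obs[14:18],
            "same_direction_agent_on_branch": obs[18:22],
            "unusable_switch_on_branch": obs[22:26],
        }

    parts = split_extra_observation(np.zeros(26))
    print({name: part.shape for name, part in parts.items()})
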
-- 
GitLab