From 225bbed91e9961f4ac5599b05e3b53e97896af72 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Tue, 17 Nov 2020 10:49:09 +0100
Subject: [PATCH] refactored file name

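Refactor the observation builder and adjust training:

* dddqn_policy.py: remove the unreachable sampling code after the
  greedy action return.
* multi_agent_training.py: add a single-agent training schedule
  (USE_SINGLE_AGENT_TRAINING until MAX_SINGLE_TRAINING_ITERATION,
  second policy refreshed every UPDATE_POLICY2_N_EPISODE episodes),
  add reward shaping driven by observation[27], and change the default
  n_episodes and training_env_config.
* run.py: switch to checkpoint 201117082153-1500.pth.
* utils/extra.py: pass the same_agents map into check_agent_can_move.
* utils/fast_tree_obs.py: split find_all_cell_where_agent_can_choose
  into find_all_switches and find_all_switch_neighbours and extend the
  observation from 27 to 32 entries:

      observation[26 + dir_loop] : 1 if new_cell_dist for that branch
                                   is infinite
      observation[30]            : 1 if the dead-lock avoidance agent
                                   suggests STOP_MOVING
      observation[31]            : 1 if only a single transition is
                                   possible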
---
 reinforcement_learning/dddqn_policy.py        |  4 -
 .../multi_agent_training.py                   | 23 ++++--
 run.py                                        |  5 +-
 utils/extra.py                                |  2 +
 utils/fast_tree_obs.py                        | 81 +++++++++----------
 5 files changed, 60 insertions(+), 55 deletions(-)

diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index b34dd36..3eb54a3 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -66,10 +66,6 @@ class DDDQNPolicy(Policy):
         # Epsilon-greedy action selection
         if random.random() >= eps:
             return np.argmax(action_values.cpu().data.numpy())
-            qvals = action_values.cpu().data.numpy()[0]
-            qvals = qvals - np.min(qvals)
-            qvals = qvals / (1e-5 + np.sum(qvals))
-            return np.argmax(np.random.multinomial(1, qvals))
         else:
             return random.choice(np.arange(self.action_size))
 
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index be905e0..b3271b5 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -171,9 +171,14 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     scores_window = deque(maxlen=checkpoint_interval)  # todo smooth when rendering instead
     completion_window = deque(maxlen=checkpoint_interval)
 
+    # If USE_SINGLE_AGENT_TRAINING is set and episode_idx <= MAX_SINGLE_TRAINING_ITERATION, training
+    # updates are applied to a single agent only. Every UPDATE_POLICY2_N_EPISODE episodes the second
+    # policy is replaced with the policy that is being trained.
+    USE_SINGLE_AGENT_TRAINING = True
+    MAX_SINGLE_TRAINING_ITERATION = 1000
+    UPDATE_POLICY2_N_EPISODE = 200
+
     # Double Dueling DQN policy
-    USE_SINGLE_AGENT_TRAINING = False
-    UPDATE_POLICY2_N_EPISODE = 1000
     policy = DDDQNPolicy(state_size, action_size, train_params)
     # policy = PPOAgent(state_size, action_size, n_agents)
     # Load existing policy
@@ -221,6 +226,9 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         preproc_timer = Timer()
         inference_timer = Timer()
 
+        if episode_idx > MAX_SINGLE_TRAINING_ITERATION:
+            USE_SINGLE_AGENT_TRAINING = False
+
         # Reset environment
         reset_timer.start()
         train_env_params.n_agents = episode_idx % n_agents + 1
@@ -293,6 +301,11 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                         if agent_obs[agent][26] == 1:
                             if act != RailEnvActions.STOP_MOVING:
                                 all_rewards[agent] -= 10.0
+                        if agent_obs[agent][27] == 1:
+                            if act == RailEnvActions.MOVE_LEFT or \
+                                    act == RailEnvActions.MOVE_RIGHT or \
+                                    act == RailEnvActions.DO_NOTHING:
+                                all_rewards[agent] -= 1.0
 
             step_timer.end()
 
@@ -310,7 +323,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 if update_values[agent] or done['__all__']:
                     # Only learn from timesteps where somethings happened
                     learn_timer.start()
-                    if agent in agent_to_learn:
+                    if agent in agent_to_learn or not USE_SINGLE_AGENT_TRAINING:
                         policy.step(agent,
                                     agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
                                     agent_obs[agent],
@@ -501,8 +514,8 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=54000, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=2000, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2,
                         type=int)
     parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
                         type=int)
diff --git a/run.py b/run.py
index 048dcde..a4fa62f 100644
--- a/run.py
+++ b/run.py
@@ -27,12 +27,11 @@ VERBOSE = True
 
 # Checkpoint to use (remember to push it!)
 # checkpoint = "./checkpoints/201112143850-5400.pth" # 21.220418678677177 DEPTH=2 AGENTS=10
-checkpoint = "./checkpoints/201113211844-6700.pth" # 19.690047767961005 DEPTH=2 AGENTS=20
-
+checkpoint = "./checkpoints/201117082153-1500.pth" # 21.570149424415636 DEPTH=2 AGENTS=10
 
 # Use last action cache
 USE_ACTION_CACHE = False
-USE_DEAD_LOCK_AVOIDANCE_AGENT = False
+USE_DEAD_LOCK_AVOIDANCE_AGENT = False # 21.54485505223213
 
 # Observation parameters (must match training parameters!)
 observation_tree_depth = 2
diff --git a/utils/extra.py b/utils/extra.py
index 89ed0bb..c4df6a8 100644
--- a/utils/extra.py
+++ b/utils/extra.py
@@ -187,6 +187,7 @@ class Extra(ObservationBuilder):
     def _check_dead_lock_at_branching_position(self, handle, new_position, branch_direction):
         _, full_shortest_distance_agent_map = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
         opp_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.opp_agent_map.get(handle, [])
+        same_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.same_agent_map.get(handle, [])
         local_walker = DeadlockAvoidanceShortestDistanceWalker(
             self.env,
             self.dead_lock_avoidance_agent.shortest_distance_walker.agent_positions,
@@ -196,6 +197,7 @@ class Extra(ObservationBuilder):
         my_shortest_path_to_check = shortest_distance_agent_map[handle]
         next_step_ok = self.dead_lock_avoidance_agent.check_agent_can_move(my_shortest_path_to_check,
                                                                            opp_agents,
+                                                                           same_agents,
                                                                            full_shortest_distance_agent_map)
         return next_step_ok
 
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index c388d2a..625a21e 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -25,7 +25,7 @@ class FastTreeObs(ObservationBuilder):
 
     def __init__(self, max_depth):
         self.max_depth = max_depth
-        self.observation_dim = 27
+        self.observation_dim = 32
 
     def build_data(self):
         if self.env is not None:
@@ -40,8 +40,8 @@ class FastTreeObs(ObservationBuilder):
         else:
             self.dead_lock_avoidance_agent = None
 
-    def find_all_cell_where_agent_can_choose(self):
-        switches = {}
+    def find_all_switches(self):
+        self.switches = {}
         for h in range(self.env.height):
             for w in range(self.env.width):
                 pos = (h, w)
@@ -49,12 +49,13 @@ class FastTreeObs(ObservationBuilder):
                     possible_transitions = self.env.rail.get_transitions(*pos, dir)
                     num_transitions = fast_count_nonzero(possible_transitions)
                     if num_transitions > 1:
-                        if pos not in switches.keys():
-                            switches.update({pos: [dir]})
+                        if pos not in self.switches.keys():
+                            self.switches.update({pos: [dir]})
                         else:
-                            switches[pos].append(dir)
+                            self.switches[pos].append(dir)
 
-        switches_neighbours = {}
+    def find_all_switch_neighbours(self):
+        self.switches_neighbours = {}
         for h in range(self.env.height):
             for w in range(self.env.width):
                 # look one step forward
@@ -64,35 +65,34 @@ class FastTreeObs(ObservationBuilder):
                     for d in range(4):
                         if possible_transitions[d] == 1:
                             new_cell = get_new_position(pos, d)
-                            if new_cell in switches.keys() and pos not in switches.keys():
-                                if pos not in switches_neighbours.keys():
-                                    switches_neighbours.update({pos: [dir]})
+                            if new_cell in self.switches.keys() and pos not in self.switches.keys():
+                                if pos not in self.switches_neighbours.keys():
+                                    self.switches_neighbours.update({pos: [dir]})
                                 else:
-                                    switches_neighbours[pos].append(dir)
+                                    self.switches_neighbours[pos].append(dir)
 
-        self.switches = switches
-        self.switches_neighbours = switches_neighbours
+    def find_all_cell_where_agent_can_choose(self):
+        self.find_all_switches()
+        self.find_all_switch_neighbours()
 
     def check_agent_decision(self, position, direction):
-        switches = self.switches
-        switches_neighbours = self.switches_neighbours
         agents_on_switch = False
         agents_on_switch_all = False
         agents_near_to_switch = False
         agents_near_to_switch_all = False
-        if position in switches.keys():
-            agents_on_switch = direction in switches[position]
+        if position in self.switches.keys():
+            agents_on_switch = direction in self.switches[position]
             agents_on_switch_all = True
 
-        if position in switches_neighbours.keys():
+        if position in self.switches_neighbours.keys():
             new_cell = get_new_position(position, direction)
-            if new_cell in switches.keys():
-                if not direction in switches[new_cell]:
-                    agents_near_to_switch = direction in switches_neighbours[position]
+            if new_cell in self.switches.keys():
+                if not direction in self.switches[new_cell]:
+                    agents_near_to_switch = direction in self.switches_neighbours[position]
             else:
-                agents_near_to_switch = direction in switches_neighbours[position]
+                agents_near_to_switch = direction in self.switches_neighbours[position]
 
-            agents_near_to_switch_all = direction in switches_neighbours[position]
+            agents_near_to_switch_all = direction in self.switches_neighbours[position]
 
         return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
 
@@ -151,15 +151,6 @@ class FastTreeObs(ObservationBuilder):
         self.build_data()
         return
 
-    def fast_argmax(self, array):
-        if array[0] == 1:
-            return 0
-        if array[1] == 1:
-            return 1
-        if array[2] == 1:
-            return 2
-        return 3
-
     def _explore(self, handle, new_position, new_direction, depth=0):
         has_opp_agent = 0
         has_same_agent = 0
@@ -269,6 +260,7 @@ class FastTreeObs(ObservationBuilder):
         # observation[24] : If there is a switch on the path which agent can not use -> 1
         # observation[25] : If there is a switch on the path which agent can not use -> 1
         # observation[26] : If there the dead-lock avoidance agent predicts a deadlock -> 1
+        # observation[27] : If the agent can only walk forward or stop -> 1
 
         observation = np.zeros(self.observation_dim)
         visited = []
@@ -313,18 +305,21 @@ class FastTreeObs(ObservationBuilder):
                     observation[14 + dir_loop] = has_opp_agent
                     observation[18 + dir_loop] = has_same_agent
                     observation[22 + dir_loop] = has_target
+                    observation[26 + dir_loop] = int(np.isinf(new_cell_dist))
+
+            agents_on_switch, \
+            agents_near_to_switch, \
+            agents_near_to_switch_all, \
+            agents_on_switch_all = \
+                self.check_agent_decision(agent_virtual_position, agent.direction)
+            observation[7] = int(agents_on_switch)
+            observation[8] = int(agents_near_to_switch)
+            observation[9] = int(agents_near_to_switch_all)
+
+            action = self.dead_lock_avoidance_agent.act([handle], 0.0)
+            observation[30] = int(action == RailEnvActions.STOP_MOVING)
+            observation[31] = int(fast_count_nonzero(possible_transitions) == 1)
 
-        agents_on_switch, \
-        agents_near_to_switch, \
-        agents_near_to_switch_all, \
-        agents_on_switch_all = \
-            self.check_agent_decision(agent_virtual_position, agent.direction)
-        observation[7] = int(agents_on_switch)
-        observation[8] = int(agents_near_to_switch)
-        observation[9] = int(agents_near_to_switch_all)
-
-        action = self.dead_lock_avoidance_agent.act([handle], 0.0)
-        observation[26] = int(action == RailEnvActions.STOP_MOVING)
         self.env.dev_obs_dict.update({handle: visited})
 
         return observation
-- 
GitLab