From 1e092263a2d65104a29cf8bada806629a83c0529 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Fri, 11 Dec 2020 21:54:47 +0100
Subject: [PATCH] Share deadlock checks, retune training, extend FastTreeObs

Move get_agent_positions() and the deadlock check (renamed from
check_for_dealock to check_for_deadlock) out of multi_agent_training.py into
utils/deadlock_check.py so the training loop and the FastTreeObs observation
builder can share them. Replace the previous reward-shaping terms with a
single penalty of 1000 for agents detected in a deadlock, slow the agent-count
curriculum to one additional agent every 500 episodes, and retune the
command-line defaults. In the PPO agent: run on CUDA when available, include
the terminal reward in the discounted return, drop reward normalisation,
optimise each episode for K_epoch passes with a slowly decaying K_epoch, and
adjust the hyperparameters. FastTreeObs grows from 36 to 41 features and now
exposes deadlock information per explored branch and for the agent itself.
run.py selects the DDDQN policy and newer checkpoints and sets EPSILON=0.01.
---
 .../multi_agent_training.py                   | 84 +++----------------
 reinforcement_learning/ppo_agent.py           | 46 +++++-----
 run.py                                        | 18 ++--
 utils/deadlock_check.py                       | 44 ++++++++++
 utils/fast_tree_obs.py                        | 13 ++-
 5 files changed, 102 insertions(+), 103 deletions(-)

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 01add0f..cc6cfce 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -22,6 +22,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from reinforcement_learning.dddqn_policy import DDDQNPolicy
 from reinforcement_learning.ppo_agent import PPOAgent
+from utils.deadlock_check import get_agent_positions, check_for_deadlock
 
 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
@@ -77,42 +78,6 @@ def create_rail_env(env_params, tree_observation):
         random_seed=seed
     )
 
-
-def get_agent_positions(env):
-    agent_positions: np.ndarray = np.full((env.height, env.width), -1)
-    for agent_handle in env.get_agent_handles():
-        agent = env.agents[agent_handle]
-        if agent.status == RailAgentStatus.ACTIVE:
-            position = agent.position
-            if position is None:
-                position = agent.initial_position
-            agent_positions[position] = agent_handle
-    return agent_positions
-
-
-def check_for_dealock(handle, env, agent_positions):
-    agent = env.agents[handle]
-    if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
-        return False
-
-    position = agent.position
-    if position is None:
-        position = agent.initial_position
-    possible_transitions = env.rail.get_transitions(*position, agent.direction)
-    num_transitions = fast_count_nonzero(possible_transitions)
-    for dir_loop in range(4):
-        if possible_transitions[dir_loop] == 1:
-            new_position = get_new_position(position, dir_loop)
-            opposite_agent = agent_positions[new_position]
-            if opposite_agent != handle and opposite_agent != -1:
-                num_transitions -= 1
-            else:
-                return False
-
-    is_deadlock = num_transitions <= 0
-    return is_deadlock
-
-
 def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     # Environment parameters
     n_agents = train_env_params.n_agents
@@ -207,7 +172,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, action_size, train_params)
-    if True:
+    if False:
         policy = PPOAgent(state_size, action_size)
     # Load existing policy
     if train_params.load_policy is not "":
@@ -256,7 +221,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
         # Reset environment
         reset_timer.start()
-        number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 200)))
+        number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 500)))
         train_env_params.n_agents = episode_idx % number_of_agents + 1
 
         train_env = create_rail_env(train_env_params, tree_observation)
@@ -318,34 +283,11 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 agent_positions = get_agent_positions(train_env)
                 for agent_handle in train_env.get_agent_handles():
                     agent = train_env.agents[agent_handle]
-
                     act = action_dict.get(agent_handle, RailEnvActions.MOVE_FORWARD)
                     if agent.status == RailAgentStatus.ACTIVE:
-                        pos = agent.position
-                        dir = agent.direction
-                        possible_transitions = train_env.rail.get_transitions(*pos, dir)
-                        num_transitions = fast_count_nonzero(possible_transitions)
-                        if act == RailEnvActions.STOP_MOVING:
-                            all_rewards[agent_handle] -= 2.0
-
-                        if num_transitions == 1:
-                            if act != RailEnvActions.MOVE_FORWARD:
-                                all_rewards[agent_handle] -= 1.0
-                        if check_for_dealock(agent_handle, train_env, agent_positions):
-                            all_rewards[agent_handle] -= 5.0
-                    elif agent.status == RailAgentStatus.READY_TO_DEPART:
-                        all_rewards[agent_handle] -= 5.0
-            else:
-                if False:
-                    agent_positions = get_agent_positions(train_env)
-                    for agent_handle in train_env.get_agent_handles():
-                        agent = train_env.agents[agent_handle]
-                        act = action_dict.get(agent_handle, RailEnvActions.MOVE_FORWARD)
-                        if agent.status == RailAgentStatus.ACTIVE:
-                            if done[agent_handle] == False:
-                                if check_for_dealock(agent_handle, train_env, agent_positions):
-                                    all_rewards[agent_handle] -= 1000.0
-                                    done[agent_handle] = True
+                        if not done[agent_handle]:
+                            if check_for_deadlock(agent_handle, train_env, agent_positions):
+                                all_rewards[agent_handle] -= 1000.0
 
             step_timer.end()
 
@@ -559,17 +501,17 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=25000, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=0,
+    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=12000, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2,
                         type=int)
     parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
                         type=int)
     parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
-    parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=2000, type=int)
-    parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
-    parser.add_argument("--eps_end", help="min exploration", default=0.05, type=float)
-    parser.add_argument("--eps_decay", help="exploration decay", default=0.9975, type=float)
-    parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e7), type=int)
+    parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
+    parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
+    parser.add_argument("--eps_end", help="min exploration", default=0.005, type=float)
+    parser.add_argument("--eps_decay", help="exploration decay", default=0.99975, type=float)
+    parser.add_argument("--buffer_size", help="replay buffer size", default=int(32_000), type=int)
     parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
     parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
     parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False,
diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py
index e603e70..e97b265 100644
--- a/reinforcement_learning/ppo_agent.py
+++ b/reinforcement_learning/ppo_agent.py
@@ -9,10 +9,12 @@ from torch.distributions import Categorical
 # Hyperparameters
 from reinforcement_learning.policy import Policy
 
-device = torch.device("cpu")  # "cuda:0" if torch.cuda.is_available() else "cpu")
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 print("device:", device)
 
 
+# https://lilianweng.github.io/lil-log/2018/04/08/policy-gradient-algorithms.html
+
 class DataBuffers:
     def __init__(self):
         self.reset()
@@ -44,7 +46,7 @@ class ActorCriticModel(nn.Module):
             nn.Tanh(),
             nn.Linear(hidsize2, action_size),
             nn.Softmax(dim=-1)
-        )
+        ).to(device)
 
         self.critic = nn.Sequential(
             nn.Linear(state_size, hidsize1),
@@ -52,7 +54,7 @@ class ActorCriticModel(nn.Module):
             nn.Linear(hidsize1, hidsize2),
             nn.Tanh(),
             nn.Linear(hidsize2, 1)
-        )
+        ).to(device)
 
     def forward(self, x):
         raise NotImplementedError
@@ -95,11 +97,11 @@ class PPOAgent(Policy):
         super(PPOAgent, self).__init__()
 
         # parameters
-        self.learning_rate = 0.1e-4
-        self.gamma = 0.99
-        self.surrogate_eps_clip = 0.2
-        self.K_epoch = 30
-        self.weight_loss = 1.0
+        self.learning_rate = 1.0e-5
+        self.gamma = 0.95
+        self.surrogate_eps_clip = 0.1
+        self.K_epoch = 50
+        self.weight_loss = 0.5
         self.weight_entropy = 0.01
 
         # objects
@@ -144,8 +146,8 @@ class PPOAgent(Policy):
                 discounted_reward = 0
                 done_list.insert(0, 1)
             else:
-                discounted_reward = reward_i + self.gamma * discounted_reward
                 done_list.insert(0, 0)
+            discounted_reward = reward_i + self.gamma * discounted_reward
             reward_list.insert(0, discounted_reward)
             state_next_list.insert(0, state_next_i)
             prob_a_list.insert(0, prob_action_i)
@@ -160,22 +162,21 @@ class PPOAgent(Policy):
             torch.tensor(prob_a_list).to(device)
 
         # standard-normalize rewards
-        rewards = (rewards - rewards.mean()) / (rewards.std() + 1.e-5)
+        # rewards = (rewards - rewards.mean()) / (rewards.std() + 1.e-5)
 
         return states, actions, rewards, states_next, dones, prob_actions
 
     def train_net(self):
-        # Optimize policy for K epochs:
-        for _ in range(self.K_epoch):
-            # All agents have to propagate their experiences made during past episode
-            for handle in range(len(self.memory)):
-                # Extract agent's episode history (list of all transitions)
-                agent_episode_history = self.memory.get_transitions(handle)
-                if len(agent_episode_history) > 0:
-                    # Convert the replay buffer to torch tensors (arrays)
-                    states, actions, rewards, states_next, dones, probs_action = \
-                        self._convert_transitions_to_torch_tensors(agent_episode_history)
-
+        # All agents have to propagate their experiences made during past episode
+        for handle in range(len(self.memory)):
+            # Extract agent's episode history (list of all transitions)
+            agent_episode_history = self.memory.get_transitions(handle)
+            if len(agent_episode_history) > 0:
+                # Convert the replay buffer to torch tensors (arrays)
+                states, actions, rewards, states_next, dones, probs_action = \
+                    self._convert_transitions_to_torch_tensors(agent_episode_history)
+                # Optimize policy for K epochs:
+                for _ in range(int(self.K_epoch)):
                     # Evaluating actions (actor) and values (critic)
                     logprobs, state_values, dist_entropy = self.actor_critic_model.evaluate(states, actions)
 
@@ -201,8 +202,9 @@ class PPOAgent(Policy):
                     self.optimizer.step()
 
                     # Transfer the current loss to the agents loss (information) for debug purpose only
-                    self.loss = loss.mean().detach().numpy()
+                    self.loss = loss.mean().detach().cpu().numpy()
 
+        self.K_epoch = max(3, self.K_epoch - 0.01)
         # Reset all collect transition data
         self.memory.reset()
 
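
Note on ppo_agent.py: _convert_transitions_to_torch_tensors now applies the
discount recursion on every step, so the reward collected on a terminal
transition is kept instead of being recorded as zero, and the reward
normalisation is disabled. train_net converts each agent's episode once and
then runs K_epoch optimisation passes over it, with K_epoch decaying from 50
towards 3 by 0.01 per call. A minimal standalone sketch of the corrected
return computation (discounted_returns is an illustration, not the class
method):

    def discounted_returns(rewards, dones, gamma=0.95):
        """Per-step discounted returns, resetting at episode boundaries."""
        returns = []
        running = 0.0
        for reward, done in zip(reversed(rewards), reversed(dones)):
            if done:
                running = 0.0                   # do not bootstrap across the boundary,
            running = reward + gamma * running  # but keep the terminal reward itself
            returns.insert(0, running)
        return returns

    # The final (done) step keeps its own reward as the return:
    print(discounted_returns([1.0, 0.0, 5.0], [False, False, True], gamma=0.5))
    # -> [2.25, 2.5, 5.0]
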
diff --git a/run.py b/run.py
index 1757f75..c4608bf 100644
--- a/run.py
+++ b/run.py
@@ -47,16 +47,18 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 # Print per-step logs
 VERBOSE = True
 USE_FAST_TREEOBS = True
-USE_PPO_AGENT = True
+USE_PPO_AGENT = False
 
 # Checkpoint to use (remember to push it!)
-checkpoint = "./checkpoints/201124171810-7800.pth"  # 18.249244799876152 DEPTH=2 AGENTS=10
-# checkpoint = "./checkpoints/201126150143-5200.pth"  # 18.249244799876152 DEPTH=2 AGENTS=10
-# checkpoint = "./checkpoints/201126160144-2000.pth"  # 18.249244799876152 DEPTH=2 AGENTS=10
-checkpoint = "./checkpoints/201127160352-2000.pth"
-checkpoint = "./checkpoints/201130083154-2000.pth"
-
-EPSILON = 0.005
+checkpoint = "./checkpoints/201124171810-7800.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10
+# checkpoint = "./checkpoints/201126150143-5200.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10
+# checkpoint = "./checkpoints/201126160144-2000.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10
+checkpoint = "./checkpoints/201207144650-20000.pth" # PPO: 14.45790721540786
+checkpoint = "./checkpoints/201211063511-6300.pth" # DDDQN: 16.948349308440857
+checkpoint = "./checkpoints/201211095604-12000.pth" # DDDQN: 17.3862941316504
+checkpoint = "./checkpoints/201211164554-8900.pth" # DDDQN: 17.44397192482364
+
+EPSILON = 0.01
 
 # Use last action cache
 USE_ACTION_CACHE = False
diff --git a/utils/deadlock_check.py b/utils/deadlock_check.py
index 6d414fa..d787c8c 100644
--- a/utils/deadlock_check.py
+++ b/utils/deadlock_check.py
@@ -1,5 +1,49 @@
+import numpy as np
+
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.rail_env import fast_count_nonzero
+
+
+def get_agent_positions(env):
+    agent_positions: np.ndarray = np.full((env.height, env.width), -1)
+    for agent_handle in env.get_agent_handles():
+        agent = env.agents[agent_handle]
+        if agent.status == RailAgentStatus.ACTIVE:
+            position = agent.position
+            if position is None:
+                position = agent.initial_position
+            agent_positions[position] = agent_handle
+    return agent_positions
+
+
+def check_for_deadlock(handle, env, agent_positions, check_position=None, check_direction=None):
+    agent = env.agents[handle]
+    if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
+        return False
+
+    position = agent.position
+    if position is None:
+        position = agent.initial_position
+    if check_position is not None:
+        position = check_position
+    direction = agent.direction
+    if check_direction is not None:
+        direction = check_direction
+
+    possible_transitions = env.rail.get_transitions(*position, direction)
+    num_transitions = fast_count_nonzero(possible_transitions)
+    for dir_loop in range(4):
+        if possible_transitions[dir_loop] == 1:
+            new_position = get_new_position(position, dir_loop)
+            opposite_agent = agent_positions[new_position]
+            if opposite_agent != handle and opposite_agent != -1:
+                num_transitions -= 1
+            else:
+                return False
+
+    is_deadlock = num_transitions <= 0
+    return is_deadlock
 
 
 def check_if_all_blocked(env):
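
Note on utils/deadlock_check.py: get_agent_positions() builds a height x width
grid holding the handle of the active agent occupying each cell (or a value of
minus one where no agent is), and check_for_deadlock() (the former
check_for_dealock from multi_agent_training.py, renamed) reports an agent as
deadlocked when every transition currently possible from its cell leads into a
cell occupied by another agent; agents that are DONE are never deadlocked. The
new optional check_position / check_direction arguments let a caller probe a
hypothetical cell and heading, which FastTreeObs uses for its per-branch
deadlock features. A minimal usage sketch, assuming env is a Flatland RailEnv
that has already been reset:

    from utils.deadlock_check import check_for_deadlock, get_agent_positions

    # Build the occupancy grid once per step, then query each agent against it.
    agent_positions = get_agent_positions(env)
    deadlocked = [handle for handle in env.get_agent_handles()
                  if check_for_deadlock(handle, env, agent_positions)]
    print("deadlocked agents:", deadlocked)

    # Probing a hypothetical move instead of the current cell:
    # check_for_deadlock(handle, env, agent_positions, new_position, new_direction)
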
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index b104916..c45477d 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -7,6 +7,7 @@ from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions
 
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
+from utils.deadlock_check import check_for_deadlock, get_agent_positions
 
 """
 LICENCE for the FastTreeObs Observation Builder  
@@ -25,7 +26,7 @@ class FastTreeObs(ObservationBuilder):
 
     def __init__(self, max_depth):
         self.max_depth = max_depth
-        self.observation_dim = 36
+        self.observation_dim = 41
 
     def build_data(self):
         if self.env is not None:
@@ -244,6 +245,7 @@ class FastTreeObs(ObservationBuilder):
 
     def get_many(self, handles: Optional[List[int]] = None):
         self.dead_lock_avoidance_agent.start_step(train=False)
+        self.agent_positions = get_agent_positions(self.env)
         observations = super().get_many(handles)
         self.dead_lock_avoidance_agent.end_step(train=False)
         return observations
@@ -328,6 +330,11 @@ class FastTreeObs(ObservationBuilder):
                     observation[19 + dir_loop] = has_same_agent
                     observation[23 + dir_loop] = has_target
                     observation[27 + dir_loop] = int(np.math.isinf(new_cell_dist))
+                    observation[36 + dir_loop] = int(check_for_deadlock(handle,
+                                                                        self.env,
+                                                                        self.agent_positions,
+                                                                        new_position,
+                                                                        branch_direction))
 
             agents_on_switch, \
             agents_near_to_switch, \
@@ -341,7 +348,9 @@ class FastTreeObs(ObservationBuilder):
             observation[10] = int(agents_near_to_switch_all)
 
             action = self.dead_lock_avoidance_agent.act([handle], 0.0)
-            observation[31] = int(action == RailEnvActions.STOP_MOVING)
+            observation[35] = int(action == RailEnvActions.STOP_MOVING)
+
+            observation[40] = int(check_for_deadlock(handle, self.env, self.agent_positions))
 
         self.env.dev_obs_dict.update({handle: visited})
 
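
Note on utils/fast_tree_obs.py: the observation grows from 36 to 41 features.
The STOP_MOVING suggestion of the deadlock-avoidance agent moves from slot 31
to slot 35, slots 36 to 39 hold a deadlock probe for each explored branch
(check_for_deadlock called with the branch's position and direction), and slot
40 flags whether the agent itself is currently deadlocked; the shared
agent_positions grid is rebuilt once per get_many() call. Checkpoints trained
against the 36-dimensional observation therefore no longer match the network
input. A minimal consistency sketch, assuming policies take their state_size
from the observation builder as elsewhere in this repository:

    from utils.fast_tree_obs import FastTreeObs

    obs_builder = FastTreeObs(max_depth=2)
    state_size = obs_builder.observation_dim
    print("state_size:", state_size)  # 41 after this patch (was 36)
    # A checkpoint saved for a 36-input network cannot be loaded into a model
    # built for 41 inputs without retraining.
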
-- 
GitLab