From 5c8a88cf573c29a575cbdc9000a7ce78a2d2db49 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Wed, 11 Nov 2020 17:52:48 +0100
Subject: [PATCH] FastTreeObs (fix) -> 0.8773

---
 reinforcement_learning/dddqn_policy.py        | 14 ++++++-
 .../multi_agent_training.py                   | 40 ++++++++++---------
 reinforcement_learning/policy.py              |  5 ++-
 utils/fast_tree_obs.py                        |  4 +-
 4 files changed, 40 insertions(+), 23 deletions(-)
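
Notes:

dddqn_policy.py: act() now compares random.random() >= eps and carries a
sampling variant over shifted Q-values. As the hunk stands, the greedy
return precedes those lines, so the sampling path is not reached. The
sketch below (a standalone helper with an illustrative name, not
repository code) shows what that rule computes when it does run:

    import numpy as np

    def sample_action_from_q(q_values, rng=np.random):
        # Shift Q-values so they are non-negative, normalise them into a
        # probability vector, then draw a single action from a multinomial.
        qvals = np.asarray(q_values, dtype=float)
        qvals = qvals - np.min(qvals)
        qvals = qvals / (1e-5 + np.sum(qvals))
        return int(np.argmax(rng.multinomial(1, qvals)))

    # Example: sample_action_from_q([1.0, 2.0, 4.0]) favours action 2,
    # returns action 1 about a quarter of the time, and never action 0
    # (its shifted value is zero).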
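
multi_agent_training.py: the DeadLockAvoidanceAgent helper is replaced by a
frozen snapshot of the learning policy (policy.clone(), refreshed every 100
episodes), and agent_to_learn becomes a set of handles drawn with
replacement and de-duplicated. A minimal sketch of that selection, assuming
only NumPy and a hypothetical helper name:

    import numpy as np

    def pick_agents_to_learn(n_agents, rng=np.random):
        # Draw n_agents handles with replacement, then de-duplicate.
        # Each handle ends up included with probability 1 - (1 - 1/n)^n,
        # which tends to 1 - 1/e (about 63%) as n grows.
        if n_agents <= 1:
            return np.array([0])
        return np.unique(rng.choice(n_agents, n_agents))

    # Example with 4 agents: a draw of [2, 0, 2, 3] yields [0, 2, 3];
    # agent 1 then acts through the cloned snapshot policy instead of the
    # learner.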
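
fast_tree_obs.py: the per-branch agent flags are no longer merged with
max(); each recursive result is accumulated with weight 2 ** (-1 - depth),
so a detection one step deeper contributes half as much. A standalone
sketch of that aggregation rule (not the FastTreeObs class itself):

    def aggregate_branches(branch_results):
        # branch_results: iterable of (has_opp_agent, has_same_agent,
        # has_target, depth) tuples from the recursive branch exploration.
        has_opp_agent = 0.0
        has_same_agent = 0.0
        has_target = 0
        for hoa, hsa, ht, depth in branch_results:
            has_opp_agent += hoa * 2 ** (-1 - depth)   # depth-discounted sum
            has_same_agent += hsa * 2 ** (-1 - depth)
            has_target = max(has_target, ht)           # target flag stays a max
        return has_opp_agent, has_same_agent, has_target

    # Example: a branch reporting an opposing agent at depth 0 adds 0.5 to
    # has_opp_agent; at depth 2 it adds 0.125.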

diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 6218ab8..0350a94 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -17,6 +17,7 @@ class DDDQNPolicy(Policy):
     """Dueling Double DQN policy"""
 
     def __init__(self, state_size, action_size, parameters, evaluation_mode=False):
+        self.parameters = parameters
         self.evaluation_mode = evaluation_mode
 
         self.state_size = state_size
@@ -59,11 +60,16 @@ class DDDQNPolicy(Policy):
         self.qnetwork_local.eval()
         with torch.no_grad():
             action_values = self.qnetwork_local(state)
+
         self.qnetwork_local.train()
 
         # Epsilon-greedy action selection
-        if random.random() > eps:
+        if random.random() >= eps:
             return np.argmax(action_values.cpu().data.numpy())
+            qvals = action_values.cpu().data.numpy()[0]
+            qvals = qvals - np.min(qvals)
+            qvals = qvals / (1e-5 + np.sum(qvals))
+            return np.argmax(np.random.multinomial(1, qvals))
         else:
             return random.choice(np.arange(self.action_size))
 
@@ -148,6 +154,12 @@ class DDDQNPolicy(Policy):
         self.act(np.array([[0] * self.state_size]))
         self._learn()
 
+    def clone(self):
+        me = DDDQNPolicy(self.state_size, self.action_size, self.parameters, evaluation_mode=True)
+        me.qnetwork_local = copy.deepcopy(self.qnetwork_local)
+        me.qnetwork_target = copy.deepcopy(self.qnetwork_target)
+        return me
+
 
 Experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
 
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 0f26a20..1a1f512 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -19,7 +19,6 @@ from flatland.utils.rendertools import RenderTool
 from torch.utils.tensorboard import SummaryWriter
 
 from reinforcement_learning.dddqn_policy import DDDQNPolicy
-from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 
 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
@@ -173,6 +172,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     completion_window = deque(maxlen=checkpoint_interval)
 
     # Double Dueling DQN policy
+    USE_SINGLE_AGENT_TRAINING = False
     policy = DDDQNPolicy(state_size, action_size, train_params)
     # policy = PPOAgent(state_size, action_size, n_agents)
     # Load existing policy
@@ -227,8 +227,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
         policy.reset()
 
-        policy2 = DeadLockAvoidanceAgent(train_env)
-        policy2.reset()
+        if episode_idx % 100 == 0:
+            policy2 = policy.clone()
 
         reset_timer.end()
 
@@ -253,9 +253,11 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         max_steps = train_env._max_episode_steps
 
         # Run episode
-        agent_to_learn = 0
+        agent_to_learn = [0]
         if train_env.get_num_agents() > 1:
-            agent_to_learn = np.random.choice(train_env.get_num_agents())
+            agent_to_learn = np.unique(np.random.choice(train_env.get_num_agents(), train_env.get_num_agents()))
+        # agent_to_learn = np.arange(train_env.get_num_agents())
+
         for step in range(max_steps - 1):
             inference_timer.start()
             policy.start_step()
@@ -263,11 +265,10 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             for agent in train_env.get_agent_handles():
                 if info['action_required'][agent]:
                     update_values[agent] = True
-
-                    if agent == agent_to_learn or True:
+                    if agent in agent_to_learn or not USE_SINGLE_AGENT_TRAINING:
                         action = policy.act(agent_obs[agent], eps=eps_start)
                     else:
-                        action = policy2.act([agent], eps=eps_start)
+                        action = policy2.act(agent_obs[agent], eps=eps_start)
                     action_count[action] += 1
                     actions_taken.append(action)
                 else:
@@ -316,7 +317,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 if update_values[agent] or done['__all__']:
                     # Only learn from timesteps where somethings happened
                     learn_timer.start()
-                    if agent == agent_to_learn:
+                    if agent in agent_to_learn:
                         policy.step(agent,
                                     agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
                                     agent_obs[agent],
@@ -507,27 +508,28 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5400, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
+                        type=int)
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
                         type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=1, type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=10, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
-    parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
+    parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
-    parser.add_argument("--eps_decay", help="exploration decay", default=0.9998, type=float)
+    parser.add_argument("--eps_decay", help="exploration decay", default=0.9975, type=float)
     parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e7), type=int)
     parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
     parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
     parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False,
                         type=bool)
     parser.add_argument("--batch_size", help="minibatch size", default=128, type=int)
-    parser.add_argument("--gamma", help="discount factor", default=0.99, type=float)
-    parser.add_argument("--tau", help="soft update of target parameters", default=1e-3, type=float)
+    parser.add_argument("--gamma", help="discount factor", default=0.97, type=float)
+    parser.add_argument("--tau", help="soft update of target parameters", default=0.5e-3, type=float)
     parser.add_argument("--learning_rate", help="learning rate", default=0.5e-4, type=float)
     parser.add_argument("--hidden_size", help="hidden size (2 fc layers)", default=128, type=int)
-    parser.add_argument("--update_every", help="how often to update the network", default=8, type=int)
-    parser.add_argument("--use_gpu", help="use GPU if available", default=False, type=bool)
-    parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
+    parser.add_argument("--update_every", help="how often to update the network", default=10, type=int)
+    parser.add_argument("--use_gpu", help="use GPU if available", default=True, type=bool)
+    parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=4, type=int)
     parser.add_argument("--render", help="render 1 episode in 100", action='store_true')
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
     parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
index c7621a6..c7300de 100644
--- a/reinforcement_learning/policy.py
+++ b/reinforcement_learning/policy.py
@@ -24,4 +24,7 @@ class Policy:
         pass
 
     def reset(self):
-        pass
\ No newline at end of file
+        pass
+
+    def clone(self):
+        return self
\ No newline at end of file
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index db22a8f..0666ef4 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -222,8 +222,8 @@ class FastTreeObs(ObservationBuilder):
                                                         dir_loop,
                                                         depth + 1)
                         visited.append(v)
-                        has_opp_agent = max(has_opp_agent, hoa)
-                        has_same_agent = max(has_same_agent, hsa)
+                        has_opp_agent += hoa * 2 ** (-1 - depth)
+                        has_same_agent += hsa * 2 ** (-1 - depth)
                         has_target = max(has_target, ht)
                 return has_opp_agent, has_same_agent, has_target, visited
             else:
-- 
GitLab