From 2ba03dda01d83f4b1eb2d83e1b760d3a58a10bf3 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 1 Sep 2019 18:27:32 -0400
Subject: [PATCH] Update training to handle stochastic events and different
 agent speeds; create feasible SARSA-style experience packages

---
 torch_training/multi_agent_training.py  |  3 ++-
 torch_training/render_agent_behavior.py |  2 +-
 torch_training/training_navigation.py   | 19 +++++++++++++++----
 3 files changed, 18 insertions(+), 6 deletions(-)
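
Note (illustration only, not part of the applied diff): the training_navigation.py
hunk below replaces the per-step agent.step() call with a delayed, SARSA-like
update. Because agents move at different speeds, an agent only registers a new
action at a decision point; in between, rewards are accumulated. The sketch
below isolates that pattern, assuming the same agent.step(obs, action, reward,
next_obs, done) interface used in the diff; the class name DelayedExperience
and the method push_transition are hypothetical names for this note.

import numpy as np


class DelayedExperience:
    """Buffer the last registered (obs, action) per agent and accumulate
    rewards until the agent's next decision point, so that agents running
    at different speeds still yield consistent replay transitions."""

    def __init__(self, n_agents):
        self.obs_buffer = [None] * n_agents
        self.action_buffer = [None] * n_agents
        self.cumulated_reward = np.zeros(n_agents)

    def push_transition(self, agent, a, obs, action, reward, done, action_registered):
        # Rewards keep accumulating while the agent is in transit between
        # decision points.
        self.cumulated_reward[a] += reward
        if not done:
            if self.obs_buffer[a] is not None and action_registered:
                # Train on the transition between the last two decision
                # points, with all intermediate rewards summed up.
                agent.step(self.obs_buffer[a], self.action_buffer[a],
                           self.cumulated_reward[a], obs.copy(), done)
                self.cumulated_reward[a] = 0.0
            if action_registered:
                # Remember the state and action taken at this decision point.
                self.obs_buffer[a] = obs.copy()
                self.action_buffer[a] = action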

diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index ba30d46..ad42e0a 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -46,7 +46,7 @@ def main(argv):
     # Parameters for the Environment
     x_dim = 20
     y_dim = 20
-    n_agents = 5
+    n_agents = 3
     tree_depth = 2
 
     # Use the malfunction generator to break agents from time to time
@@ -163,6 +163,7 @@ def main(argv):
         # different times during an episode
         final_obs = agent_obs.copy()
         final_obs_next = agent_next_obs.copy()
+        register_action_state = np.zeros(env.get_num_agents(), dtype=bool)
 
         # Build agent specific observations
         for a in range(env.get_num_agents()):
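
Note (illustration only): register_action_state, introduced above and used in
the training_navigation.py hunk further below, marks which agents actually had
an action consumed by the environment this step. A hypothetical sketch of how
it can be filled during action selection follows; can_choose is an assumed
per-agent flag supplied by the caller (e.g. the agent sitting at a cell-entry
decision point), and agent.act(obs, eps=eps) is assumed to match the DQN Agent
interface used elsewhere in torch_training.

import numpy as np


def select_actions(agent, agent_obs, can_choose, eps):
    # Pick a fresh action only for agents that are at a decision point;
    # agents still in transit get a no-op and are not flagged for storage.
    n_agents = len(agent_obs)
    action_dict = {}
    register_action_state = np.zeros(n_agents, dtype=bool)
    for a in range(n_agents):
        if can_choose[a]:
            action_dict[a] = agent.act(agent_obs[a], eps=eps)
            register_action_state[a] = True
        else:
            action_dict[a] = 0  # no-op while the agent is still moving
    return action_dict, register_action_state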
diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index f41cbb9..aabd457 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -102,7 +102,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-with path(torch_training.Nets, "navigator_checkpoint1100.pth") as file_in:
+with path(torch_training.Nets, "navigator_checkpoint9600.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 record_images = False
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 2417746..8ec2b08 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -87,7 +87,7 @@ def main(argv):
 
     # We set the number of episodes we would like to train on
     if 'n_trials' not in locals():
-        n_trials = 6000
+        n_trials = 15000
 
     # And the max number of steps we want to take per episode
     max_steps = int(3 * (env.height + env.width))
@@ -107,6 +107,9 @@ def main(argv):
     action_prob = [0] * action_size
     agent_obs = [None] * env.get_num_agents()
     agent_next_obs = [None] * env.get_num_agents()
+    agent_obs_buffer = [None] * env.get_num_agents()
+    agent_action_buffer = [None] * env.get_num_agents()
+    cummulated_reward = np.zeros(env.get_num_agents())
 
     # Now we load a Double dueling DQN agent
     agent = Agent(state_size, action_size, "FC", 0)
@@ -146,15 +149,23 @@ def main(argv):
             # Build agent specific observations and normalize
             for a in range(env.get_num_agents()):
                 agent_next_obs[a] = normalize_observation(next_obs[a], observation_radius=10)
-
+                cummulated_reward[a] += all_rewards[a]
             # Update replay buffer and train agent
             for a in range(env.get_num_agents()):
                 if done[a]:
                     final_obs[a] = agent_obs[a].copy()
                     final_obs_next[a] = agent_next_obs[a].copy()
                     final_action_dict.update({a: action_dict[a]})
-                if not done[a] and register_action_state[a]:
-                    agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
+                if not done[a]:
+                    if agent_obs_buffer[a] is not None and register_action_state[a]:
+                        agent_delayed_next = agent_obs[a].copy()
+                        agent.step(agent_obs_buffer[a], agent_action_buffer[a], cummulated_reward[a],
+                                   agent_delayed_next, done[a])
+                        cummulated_reward[a] = 0.
+                    if register_action_state[a]:
+                        agent_obs_buffer[a] = agent_obs[a].copy()
+                        agent_action_buffer[a] = action_dict[a]
+
                 score += all_rewards[a] / env.get_num_agents()
 
             # Copy observation
-- 
GitLab