From 8e79b68f17891c2a889203b6ac886ac46f0bbbbd Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Wed, 15 May 2019 14:18:50 +0200
Subject: [PATCH] Updated training: New state has two time frames

---
 examples/training_navigation.py | 35 ++++++++++++++++++++++-----------
 1 file changed, 24 insertions(+), 11 deletions(-)

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index 0262640..9d45cd1 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -45,7 +45,7 @@ env = RailEnv(width=20,
 env_renderer = RenderTool(env, gl="QT")
 handle = env.get_agent_handles()
 
-state_size = 105
+state_size = 105 * 2
 action_size = 4
 n_trials = 15000
 eps = 1.
@@ -55,13 +55,16 @@ action_dict = dict()
 final_action_dict = dict()
 scores_window = deque(maxlen=100)
 done_window = deque(maxlen=100)
+time_obs = deque(maxlen=2)
 scores = []
 dones_list = []
 action_prob = [0] * 4
+agent_obs = [None] * env.get_num_agents()
+agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+# agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
 
-demo = True
+demo = False
 
 
 def max_lt(seq, val):
@@ -103,11 +106,11 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
     max_obs = max(1, max_lt(obs, 1000))
     min_obs = max(0, min_lt(obs, 0))
     if max_obs == min_obs:
-        return np.clip(np.array(obs)/ max_obs, clip_min, clip_max)
+        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
     norm = np.abs(max_obs - min_obs)
     if norm == 0:
         norm = 1.
-    return np.clip((np.array(obs)-min_obs)/ norm, clip_min, clip_max)
+    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
 
 
 for trials in range(1, n_trials + 1):
@@ -115,13 +118,18 @@ for trials in range(1, n_trials + 1):
     # Reset environment
     obs = env.reset()
     final_obs = obs.copy()
-    final_obs_next = obs.copy()
+    final_obs_next = obs.copy()
     for a in range(env.get_num_agents()):
         data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
         obs[a] = np.concatenate((data, distance))
+
+    for i in range(2):
+        time_obs.append(obs)
     # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+    for a in range(env.get_num_agents()):
+        agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
 
     score = 0
     env_done = 0
@@ -134,7 +142,8 @@ for trials in range(1, n_trials + 1):
         for a in range(env.get_num_agents()):
             if demo:
                 eps = 0
-            action = agent.act(np.array(obs[a]), eps=eps)
+            # action = agent.act(np.array(obs[a]), eps=eps)
+            action = agent.act(agent_obs[a])
             action_prob[action] += 1
             action_dict.update({a: action})
 
@@ -148,17 +157,21 @@ for trials in range(1, n_trials + 1):
             distance = norm_obs_clip(distance)
             next_obs[a] = np.concatenate((data, distance))
 
+        time_obs.append(next_obs)
+
         # Update replay buffer and train agent
        for a in range(env.get_num_agents()):
+            agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
+
             if done[a]:
-                final_obs[a] = obs[a].copy()
-                final_obs_next[a] = next_obs[a].copy()
+                final_obs[a] = agent_obs[a].copy()
+                final_obs_next[a] = agent_next_obs[a].copy()
                 final_action_dict.update({a: action_dict[a]})
             if not demo and not done[a]:
-                agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
+                agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
             score += all_rewards[a]
 
-        obs = next_obs.copy()
+        agent_obs = agent_next_obs.copy()
         if done['__all__']:
             env_done = 1
             for a in range(env.get_num_agents()):
--
GitLab
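
A minimal sketch of the frame-stacking idea this commit introduces, separate from the patch itself: the last two per-agent observations are kept in a `deque(maxlen=2)` and concatenated into one state vector of twice the single-frame size (hence `state_size = 105 * 2`). The `num_agents` value and the `fake_observations` helper below are assumptions for illustration only, not part of the commit.

```python
# Illustrative sketch (not from the patch): stack the two most recent
# per-agent observations into one state vector, mirroring the roles of
# time_obs, agent_obs and agent_next_obs in training_navigation.py.
from collections import deque

import numpy as np

num_agents = 3   # assumed value for the example
obs_size = 105   # single-frame observation size used in the script
time_obs = deque(maxlen=2)


def fake_observations():
    # Stand-in for the normalized tree observations built in the script.
    return {a: np.random.rand(obs_size) for a in range(num_agents)}


# On reset, the same first observation fills both time slots.
first_obs = fake_observations()
for _ in range(2):
    time_obs.append(first_obs)

agent_obs = [np.concatenate((time_obs[0][a], time_obs[1][a]))
             for a in range(num_agents)]
assert agent_obs[0].shape == (2 * obs_size,)  # matches state_size = 105 * 2

# After each environment step, append the new observation; the deque drops
# the oldest frame, so the stacked state always covers the last two frames.
time_obs.append(fake_observations())
agent_next_obs = [np.concatenate((time_obs[0][a], time_obs[1][a]))
                  for a in range(num_agents)]
```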