From b42ef74e64bf32df121196208b5bbc95e51994ec Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Tue, 23 Apr 2019 16:39:10 +0200
Subject: [PATCH] updated training for navigation

---
 examples/training_navigation.py | 36 +++++++++++++++++++++++----------
 1 file changed, 25 insertions(+), 11 deletions(-)

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index eddb907..231d4e9 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -20,8 +20,8 @@ transition_probability = [10.0,  # empty cell - Case 0
                           0.0]  # Case 7 - dead end
 
 # Example generate a random rail
-env = RailEnv(width=7,
-              height=7,
+env = RailEnv(width=5,
+              height=5,
               rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
               number_of_agents=1)
 env_renderer = RenderTool(env)
@@ -29,7 +29,7 @@ handle = env.get_agent_handles()
 
 state_size = 105
 action_size = 4
-n_trials = 5000
+n_trials = 9999
 eps = 1.
 eps_end = 0.005
 eps_decay = 0.998
@@ -40,14 +40,27 @@ scores = []
 dones_list = []
 action_prob = [0]*4
 agent = Agent(state_size, action_size, "FC", 0)
+agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint8000.pth'))
+def max_lt(seq, val):
+    """
+    Return greatest item in seq for which item < val applies.
+    None is returned if seq was empty or all items in seq were >= val.
+    """
+
+    idx = len(seq)-1
+    while idx >= 0:
+        if seq[idx] < val and seq[idx] > 0:
+            return seq[idx]
+        idx -= 1
+    return None
 
 for trials in range(1, n_trials + 1):
 
     # Reset environment
     obs = env.reset()
     for a in range(env.number_of_agents):
-        if np.max(obs[a]) > 0 and np.max(obs[a]) < np.inf:
-            obs[a] = np.clip(obs[a] / np.max(obs[a]), -1, 1)
+        norm = max(1, max_lt(obs[a],np.inf))
+        obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1)
 
     # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
 
@@ -55,21 +68,21 @@ for trials in range(1, n_trials + 1):
     env_done = 0
 
     # Run episode
-    for step in range(100):
+    for step in range(50):
         #if trials > 114:
-        #    env_renderer.renderEnv(show=True)
-
+        #env_renderer.renderEnv(show=True)
+        #print(step)
         # Action
         for a in range(env.number_of_agents):
-            action = agent.act(np.array(obs[a]), eps=eps)
+            action = agent.act(np.array(obs[a]), eps=0)
             action_prob[action] += 1
             action_dict.update({a: action})
 
         # Environment step
         next_obs, all_rewards, done, _ = env.step(action_dict)
         for a in range(env.number_of_agents):
-            if np.max(next_obs[a]) > 0 and np.max(next_obs[a]) < np.inf:
-                next_obs[a] = np.clip(next_obs[a] / np.max(next_obs[a]), -1, 1)
+            norm = max(1, max_lt(next_obs[a], np.inf))
+            next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1)
         # Update replay buffer and train agent
         for a in range(env.number_of_agents):
             agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
@@ -108,3 +121,4 @@ for trials in range(1, n_trials + 1):
                 eps, action_prob / np.sum(action_prob)))
         torch.save(agent.qnetwork_local.state_dict(),
                    '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
+
-- 
GitLab