From 73919673f7d31938cd60b4a24dd3dcaf38f8b1b7 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Sat, 22 Jun 2019 12:20:40 -0500
Subject: [PATCH] minor test of multi speed implementation

---
 torch_training/training_navigation.py | 27 +++++++++++++++++----------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index c043372..634db69 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -40,21 +40,22 @@ env = RailEnv(width=15,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
               number_of_agents=1)
 
-
+"""
 env = RailEnv(width=10,
               height=20, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
 env.load("./railway/complex_scene.pkl")
+file_load = True
 """
 
-env = RailEnv(width=12,
-              height=12,
-              rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=10, max_dist=99999, seed=0),
+env = RailEnv(width=20,
+              height=20,
+              rail_generator=complex_rail_generator(nr_start_goal=20, nr_extra=5, min_dist=10, max_dist=99999, seed=0),
               obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-              number_of_agents=5)
-
+              number_of_agents=15)
+file_load = False
 env.reset(True, True)
-
-env_renderer = RenderTool(env, gl="PILSVG")
+"""
+env_renderer = RenderTool(env, gl="PILSVG",)
 handle = env.get_agent_handles()
 
 state_size = 168 * 2
@@ -78,6 +79,7 @@ agent = Agent(state_size, action_size, "FC", 0)
 agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
 
 demo = True
+record_images = False
 
 def max_lt(seq, val):
     """
@@ -129,7 +131,10 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
 for trials in range(1, n_trials + 1):
 
     # Reset environment
-    obs = env.reset(True, True)
+    if file_load :
+        obs = env.reset(False, False)
+    else:
+        obs = env.reset(True, True)
     if demo:
         env_renderer.set_new_rail()
     final_obs = obs.copy()
@@ -154,13 +159,15 @@ for trials in range(1, n_trials + 1):
     for step in range(max_steps):
         if demo:
             env_renderer.renderEnv(show=True, show_observations=False)
+            if record_images:
+                env_renderer.gl.saveImage("./Images/frame_{:04d}.bmp".format(step))
         # print(step)
         # Action
         for a in range(env.get_num_agents()):
             if demo:
                 eps = 0
             # action = agent.act(np.array(obs[a]), eps=eps)
-            action = agent.act(agent_obs[a], eps=eps)
+            action = 2 #agent.act(agent_obs[a], eps=eps)
             action_prob[action] += 1
             action_dict.update({a: action})
         # Environment step
-- 
GitLab