From 3af1729d0c13087af68cedd9cccf3371817f9ad6 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Mon, 20 May 2019 17:27:50 +0100
Subject: [PATCH] Fix lint errors and comment out unused Agent / PyTorch references

---
 examples/play_model.py | 34 +++++++++++++++++++---------------
 1 file changed, 19 insertions(+), 15 deletions(-)
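
Patch note (commentary only, ignored by git am): the hunks below drop the
commented-out Agent/PyTorch code paths and keep the random-action fallback,
tallying the chosen actions per step. A minimal standalone sketch of that
fallback loop follows; agent_handles is a hypothetical stand-in for
env.get_agent_handles() from the diff, everything else mirrors the patched
code.

    import random

    agent_handles = [0, 1, 2]   # hypothetical stand-in for env.get_agent_handles()
    action_prob = [0] * 4       # per-action tally, as in Player.__init__
    action_dict = {}

    for handle in agent_handles:
        # Random policy standing in for the (commented-out) agent.act(...) call
        action = random.randint(0, 3)
        action_prob[action] += 1
        action_dict.update({handle: action})

    print(action_dict, action_prob)
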

diff --git a/examples/play_model.py b/examples/play_model.py
index 777f2d34..bbbcfffd 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -26,9 +26,11 @@ class Player(object):
         self.scores = []
         self.dones_list = []
         self.action_prob = [0]*4
+
+        # Removing refs to a real agent for now.
         # self.agent = Agent(self.state_size, self.action_size, "FC", 0)
         # self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
-        #self.agent.qnetwork_local.load_state_dict(torch.load(
+        # self.agent.qnetwork_local.load_state_dict(torch.load(
         #    '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
 
         self.iFrame = 0
@@ -56,8 +58,11 @@ class Player(object):
 
         # Pass the (stored) observation to the agent network and retrieve the action
         for handle in env.get_agent_handles():
+            # Real Agent
             # action = self.agent.act(np.array(self.obs[handle]), eps=self.eps)
+            # Random actions
             action = random.randint(0, 3)
+            # Numpy version uses single random sequence
             # action = np.random.randint(0, 4, size=1)
             self.action_prob[action] += 1
             self.action_dict.update({handle: action})
@@ -65,7 +70,6 @@ class Player(object):
         # Environment step - pass the agent actions to the environment,
         # retrieve the response - observations, rewards, dones
         next_obs, all_rewards, done, _ = self.env.step(self.action_dict)
-        next_obs = next_obs
         
         for handle in env.get_agent_handles():
             norm = max(1, max_lt(next_obs[handle], np.inf))
@@ -117,7 +121,7 @@ def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="QT"):
 
     for trials in range(1, n_trials + 1):
 
-        # Reset environment8
+        # Reset environment
         oPlayer.reset()
         env_renderer.set_new_rail()
 
@@ -156,8 +160,6 @@ def main_old(render=True, delay=0.0):
         env_renderer = RenderTool(env, gl="QTSVG")
         # env_renderer = RenderTool(env, gl="QT")
 
-    state_size = 105
-    action_size = 4
     n_trials = 9999
     eps = 1.
     eps_end = 0.005
@@ -168,7 +170,11 @@ def main_old(render=True, delay=0.0):
     scores = []
     dones_list = []
     action_prob = [0]*4
-    agent = Agent(state_size, action_size, "FC", 0)
+
+    # Real Agent
+    # state_size = 105
+    # action_size = 4
+    # agent = Agent(state_size, action_size, "FC", 0)
     # agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
 
     def max_lt(seq, val):
@@ -188,12 +194,10 @@ def main_old(render=True, delay=0.0):
     tStart = time.time()
     for trials in range(1, n_trials + 1):
 
-        # Reset environment8
+        # Reset environment
         obs = env.reset()
         env_renderer.set_new_rail()
 
-        #obs = obs[0]
-
         for a in range(env.get_num_agents()):
             norm = max(1, max_lt(obs[a], np.inf))
             obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1)
@@ -210,13 +214,12 @@ def main_old(render=True, delay=0.0):
             # print(step)
             # Action
             for a in range(env.get_num_agents()):
-                action = random.randint(0,3)  # agent.act(np.array(obs[a]), eps=eps)
+                action = random.randint(0, 3)  # agent.act(np.array(obs[a]), eps=eps)
                 action_prob[action] += 1
                 action_dict.update({a: action})
 
             if render:
                 env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step, action_dict=action_dict)
-                #time.sleep(10)
                 if delay > 0:
                     time.sleep(delay)
 
@@ -224,15 +227,16 @@ def main_old(render=True, delay=0.0):
 
             # Environment step
             next_obs, all_rewards, done, _ = env.step(action_dict)
-            #next_obs = next_obs[0]
 
             for a in range(env.get_num_agents()):
                 norm = max(1, max_lt(next_obs[a], np.inf))
                 next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1)
+
             # Update replay buffer and train agent
-            for a in range(env.get_num_agents()):
-                agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
-                score += all_rewards[a]
+            # only needed for "real" agent
+            # for a in range(env.get_num_agents()):
+            #    agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
+            #    score += all_rewards[a]
 
             obs = next_obs.copy()
             if done['__all__']:
-- 
GitLab