From 32d2635f126e3a55502a6bf383df3cfb6f409cb9 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Fri, 17 May 2019 13:08:17 +0100
Subject: [PATCH] butchered Player to remove all trace of the NN agent and torch

---
 examples/play_model.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/examples/play_model.py b/examples/play_model.py
index e502bd2..5ed6feb 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -1,9 +1,9 @@
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.generators import complex_rail_generator
 from flatland.utils.rendertools import RenderTool
-from flatland.baselines.dueling_double_dqn import Agent
+# from flatland.baselines.dueling_double_dqn import Agent
 from collections import deque
-import torch
+# import torch
 import random
 import numpy as np
 import time
@@ -26,7 +26,7 @@ class Player(object):
         self.scores = []
         self.dones_list = []
         self.action_prob = [0]*4
-        self.agent = Agent(self.state_size, self.action_size, "FC", 0)
+        # self.agent = Agent(self.state_size, self.action_size, "FC", 0)
         # self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
         #self.agent.qnetwork_local.load_state_dict(torch.load(
         #    '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
@@ -72,11 +72,12 @@ class Player(object):
             next_obs[handle] = np.clip(np.array(next_obs[handle]) / norm, -1, 1)
 
         # Update replay buffer and train agent
-        for handle in self.env.get_agent_handles():
-            self.agent.step(self.obs[handle], self.action_dict[handle],
-                all_rewards[handle], next_obs[handle], done[handle],
-                train=False)
-            self.score += all_rewards[handle]
+        if False:
+            for handle in self.env.get_agent_handles():
+                self.agent.step(self.obs[handle], self.action_dict[handle],
+                    all_rewards[handle], next_obs[handle], done[handle],
+                    train=False)
+                self.score += all_rewards[handle]
 
         self.iFrame += 1
 
@@ -263,8 +264,8 @@ def main_old(render=True, delay=0.0):
                    np.mean(scores_window),
                    100 * np.mean(done_window),
                    eps, rFps, action_prob / np.sum(action_prob)))
-            torch.save(agent.qnetwork_local.state_dict(),
-                    '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
+            # torch.save(agent.qnetwork_local.state_dict(),
+            #         '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
             action_prob = [1]*4
 
 
-- 
GitLab
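
Editor's note: for readers adapting this patch, below is a minimal sketch of the same idea expressed without `if False:` guards or commented-out imports: the torch-backed agent becomes optional, its import is deferred, and the replay-buffer update is gated on whether an agent was constructed. The `use_agent` flag, the `record_transition` helper, and the constructor signature are illustrative assumptions, not code from the repository; only the `Agent` import path, its constructor arguments, and the `agent.step(...)` call are taken from the diff above.

    # Hedged sketch only; not part of the patch.
    class Player(object):
        def __init__(self, env, state_size, action_size, use_agent=False):
            self.env = env
            self.state_size = state_size
            self.action_size = action_size
            self.score = 0
            self.iFrame = 0
            self.agent = None
            if use_agent:
                # Lazy import keeps the example importable when torch is not installed.
                from flatland.baselines.dueling_double_dqn import Agent
                self.agent = Agent(self.state_size, self.action_size, "FC", 0)

        def record_transition(self, obs, action, reward, next_obs, done):
            # Feed the replay buffer only when an agent exists; keep scoring either way.
            if self.agent is not None:
                self.agent.step(obs, action, reward, next_obs, done, train=False)
            self.score += reward
            self.iFrame += 1

One behavioral difference worth noting: the `if False:` block in the patch also disables the `self.score += all_rewards[handle]` accumulation, since that line sits inside the disabled block, whereas the sketch keeps scoring independent of whether an agent is attached.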