From 32d2635f126e3a55502a6bf383df3cfb6f409cb9 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Fri, 17 May 2019 13:08:17 +0100
Subject: [PATCH] butchered Player to remove all trace of nn agent and torch

---
 examples/play_model.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/examples/play_model.py b/examples/play_model.py
index e502bd2..5ed6feb 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -1,9 +1,9 @@
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.generators import complex_rail_generator
 from flatland.utils.rendertools import RenderTool
-from flatland.baselines.dueling_double_dqn import Agent
+# from flatland.baselines.dueling_double_dqn import Agent
 from collections import deque
-import torch
+# import torch
 import random
 import numpy as np
 import time
@@ -26,7 +26,7 @@ class Player(object):
         self.scores = []
         self.dones_list = []
         self.action_prob = [0]*4
-        self.agent = Agent(self.state_size, self.action_size, "FC", 0)
+        # self.agent = Agent(self.state_size, self.action_size, "FC", 0)
         # self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
         #self.agent.qnetwork_local.load_state_dict(torch.load(
         #    '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
@@ -72,11 +72,12 @@ class Player(object):
             next_obs[handle] = np.clip(np.array(next_obs[handle]) / norm, -1, 1)

         # Update replay buffer and train agent
-        for handle in self.env.get_agent_handles():
-            self.agent.step(self.obs[handle], self.action_dict[handle],
-                            all_rewards[handle], next_obs[handle], done[handle],
-                            train=False)
-            self.score += all_rewards[handle]
+        if False:
+            for handle in self.env.get_agent_handles():
+                self.agent.step(self.obs[handle], self.action_dict[handle],
+                                all_rewards[handle], next_obs[handle], done[handle],
+                                train=False)
+                self.score += all_rewards[handle]

         self.iFrame += 1

@@ -263,8 +264,8 @@ def main_old(render=True, delay=0.0):
                 np.mean(scores_window), 100 * np.mean(done_window), eps, rFps,
                 action_prob / np.sum(action_prob)))
-            torch.save(agent.qnetwork_local.state_dict(),
-                       '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
+            # torch.save(agent.qnetwork_local.state_dict(),
+            #            '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
             action_prob = [1]*4
--
GitLab