Skip to content
Snippets Groups Projects
Commit 32d2635f authored by hagrid67
Browse files

butchered Player to remove all trace of nn agent and torch

parent 280dd97e
No related branches found
No related tags found
No related merge requests found
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import complex_rail_generator
from flatland.utils.rendertools import RenderTool
from flatland.baselines.dueling_double_dqn import Agent
# from flatland.baselines.dueling_double_dqn import Agent
from collections import deque
import torch
# import torch
import random
import numpy as np
import time
......@@ -26,7 +26,7 @@ class Player(object):
self.scores = []
self.dones_list = []
self.action_prob = [0]*4
self.agent = Agent(self.state_size, self.action_size, "FC", 0)
# self.agent = Agent(self.state_size, self.action_size, "FC", 0)
# self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
#self.agent.qnetwork_local.load_state_dict(torch.load(
# '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
......@@ -72,11 +72,12 @@ class Player(object):
next_obs[handle] = np.clip(np.array(next_obs[handle]) / norm, -1, 1)
# Update replay buffer and train agent
for handle in self.env.get_agent_handles():
self.agent.step(self.obs[handle], self.action_dict[handle],
all_rewards[handle], next_obs[handle], done[handle],
train=False)
self.score += all_rewards[handle]
if False:
for handle in self.env.get_agent_handles():
self.agent.step(self.obs[handle], self.action_dict[handle],
all_rewards[handle], next_obs[handle], done[handle],
train=False)
self.score += all_rewards[handle]
self.iFrame += 1
......@@ -263,8 +264,8 @@ def main_old(render=True, delay=0.0):
np.mean(scores_window),
100 * np.mean(done_window),
eps, rFps, action_prob / np.sum(action_prob)))
torch.save(agent.qnetwork_local.state_dict(),
'../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
# torch.save(agent.qnetwork_local.state_dict(),
# '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
action_prob = [1]*4
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment