From aad421906712ab9b9bbd0990a6a836290b143dba Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Fri, 24 May 2019 14:48:06 +0200 Subject: [PATCH] changed paths to work correctly --- torch_training/dueling_double_dqn.py | 2 +- torch_training/training_navigation.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_training/dueling_double_dqn.py b/torch_training/dueling_double_dqn.py index c2ff1a7..3b98a3a 100644 --- a/torch_training/dueling_double_dqn.py +++ b/torch_training/dueling_double_dqn.py @@ -8,7 +8,7 @@ import torch import torch.nn.functional as F import torch.optim as optim -from baselines.torch_training.model import QNetwork, QNetwork2 +from model import QNetwork, QNetwork2 BUFFER_SIZE = int(1e5) # replay buffer size BATCH_SIZE = 512 # minibatch size diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 8736e3d..b30278c 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -4,7 +4,7 @@ from collections import deque import numpy as np import torch -from baselines.torch_training.dueling_double_dqn import Agent +from dueling_double_dqn import Agent from flatland.envs.generators import complex_rail_generator from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool @@ -46,7 +46,7 @@ env = RailEnv(width=20, number_of_agents=3) """ -env_renderer = RenderTool(env, gl="QTSVG") +env_renderer = RenderTool(env, gl="QT") handle = env.get_agent_handles() state_size = 105 * 2 @@ -66,7 +66,7 @@ action_prob = [0] * 4 agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -agent.qnetwork_local.load_state_dict(torch.load('./baselines/torch_training/Nets/avoid_checkpoint15000.pth')) +agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth')) demo = True -- GitLab