From aad421906712ab9b9bbd0990a6a836290b143dba Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Fri, 24 May 2019 14:48:06 +0200
Subject: [PATCH] changed paths to work correctly

---
 torch_training/dueling_double_dqn.py  | 2 +-
 torch_training/training_navigation.py | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/torch_training/dueling_double_dqn.py b/torch_training/dueling_double_dqn.py
index c2ff1a7..3b98a3a 100644
--- a/torch_training/dueling_double_dqn.py
+++ b/torch_training/dueling_double_dqn.py
@@ -8,7 +8,7 @@ import torch
 import torch.nn.functional as F
 import torch.optim as optim
 
-from baselines.torch_training.model import QNetwork, QNetwork2
+from model import QNetwork, QNetwork2
 
 BUFFER_SIZE = int(1e5)  # replay buffer size
 BATCH_SIZE = 512  # minibatch size
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 8736e3d..b30278c 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -4,7 +4,7 @@ from collections import deque
 import numpy as np
 import torch
 
-from baselines.torch_training.dueling_double_dqn import Agent
+from dueling_double_dqn import Agent
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
@@ -46,7 +46,7 @@ env = RailEnv(width=20,
               number_of_agents=3)
 
 """
-env_renderer = RenderTool(env, gl="QTSVG")
+env_renderer = RenderTool(env, gl="QT")
 handle = env.get_agent_handles()
 
 state_size = 105 * 2
@@ -66,7 +66,7 @@ action_prob = [0] * 4
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('./baselines/torch_training/Nets/avoid_checkpoint15000.pth'))
+agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
 
 demo = True
 
-- 
GitLab