diff --git a/torch_training/dueling_double_dqn.py b/torch_training/dueling_double_dqn.py
index c2ff1a7f1f8502926efa88c932c31bff1a2ed179..3b98a3a62a5a6b9e1cd1b4732b46831d5dfee95d 100644
--- a/torch_training/dueling_double_dqn.py
+++ b/torch_training/dueling_double_dqn.py
@@ -8,7 +8,7 @@ import torch
 import torch.nn.functional as F
 import torch.optim as optim
 
-from baselines.torch_training.model import QNetwork, QNetwork2
+from model import QNetwork, QNetwork2
 
 BUFFER_SIZE = int(1e5)  # replay buffer size
 BATCH_SIZE = 512  # minibatch size
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 8736e3d8e665e521e2b34de9136488925782ee28..b30278cf4865088382a352dd1ea9e242b8ef0567 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -4,7 +4,7 @@ from collections import deque
 import numpy as np
 import torch
 
-from baselines.torch_training.dueling_double_dqn import Agent
+from dueling_double_dqn import Agent
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
@@ -46,7 +46,7 @@ env = RailEnv(width=20,
               number_of_agents=3)
 """
 
-env_renderer = RenderTool(env, gl="QTSVG")
+env_renderer = RenderTool(env, gl="QT")
 handle = env.get_agent_handles()
 
 state_size = 105 * 2
@@ -66,7 +66,7 @@ action_prob = [0] * 4
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('./baselines/torch_training/Nets/avoid_checkpoint15000.pth'))
+agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
 
 demo = True
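
The changes above replace the baselines.torch_training package-absolute imports and checkpoint path with directory-relative ones, so both scripts now assume they are launched from inside torch_training/. A minimal sketch of restoring the pre-trained agent under that assumption; the sizes, the Agent constructor arguments, and the checkpoint path are taken from training_navigation.py as shown in the diff, and the eval() call assumes qnetwork_local is a torch.nn.Module (which the load_state_dict call implies):

import torch

from dueling_double_dqn import Agent

# Sizes as defined in training_navigation.py.
state_size = 105 * 2
action_size = 4

# Build the agent with the fully connected network variant and load the
# pre-trained weights via the new directory-relative checkpoint path.
agent = Agent(state_size, action_size, "FC", 0)
agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
agent.qnetwork_local.eval()  # demo/inference only, no further training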