diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 4524d784b0bd8b9e6a67970be36192b94984a090..216ddc66c439027fdf4a9603f5dc7a4091c246bf 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -7,6 +7,7 @@ import random import torch from dueling_double_dqn import Agent +import torch_training from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv @@ -63,7 +64,7 @@ action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -with path("torch_training/Nets", "avoid_checkpoint30000.pth") as file_in: +with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) demo = True diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 987c7d645ae9d574d1427fc0c686e740abdb449c..095f20b4a91c0cab1e97fdec55cf419d3cd77be7 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -7,6 +7,8 @@ import matplotlib.pyplot as plt import numpy as np import torch from dueling_double_dqn import Agent + +import torch_training from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv @@ -89,7 +91,7 @@ action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -with path("torch_training/Nets", "avoid_checkpoint30000.pth") as file_in: +with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in))