diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index eb11fb1b183d7041e35186f40693b378c80eec68..4524d784b0bd8b9e6a67970be36192b94984a090 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -1,16 +1,17 @@ -import random from collections import deque +from sys import path import matplotlib.pyplot as plt import numpy as np +import random import torch from dueling_double_dqn import Agent + from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool - from utils.observation_utils import norm_obs_clip, split_tree random.seed(1) @@ -62,7 +63,8 @@ action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint30000.pth')) +with path("torch_training/Nets", "avoid_checkpoint30000.pth") as file_in: + agent.qnetwork_local.load_state_dict(torch.load(file_in)) demo = True record_images = False diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 3007ce59bed2e75adc0672511e397c39eea1e2cb..987c7d645ae9d574d1427fc0c686e740abdb449c 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -1,3 +1,5 @@ +from sys import path + import random from collections import deque @@ -87,7 +89,9 @@ action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint30000.pth')) +with path("torch_training/Nets", "avoid_checkpoint30000.pth") as file_in: + agent.qnetwork_local.load_state_dict(torch.load(file_in)) + demo = True record_images = False