From eb71fa51a3e04e48bac45f7a2a48ae5fb141dbe4 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Tue, 9 Jul 2019 11:22:52 +0200 Subject: [PATCH] #42 run baselines in ci --- torch_training/multi_agent_training.py | 3 ++- torch_training/training_navigation.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 4524d78..216ddc6 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -7,6 +7,7 @@ import random import torch from dueling_double_dqn import Agent +import torch_training from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv @@ -63,7 +64,7 @@ action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -with path("torch_training/Nets", "avoid_checkpoint30000.pth") as file_in: +with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) demo = True diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 987c7d6..095f20b 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -7,6 +7,8 @@ import matplotlib.pyplot as plt import numpy as np import torch from dueling_double_dqn import Agent + +import torch_training from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv @@ -89,7 +91,7 @@ action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -with path("torch_training/Nets", "avoid_checkpoint30000.pth") as file_in: +with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) -- GitLab