From eb71fa51a3e04e48bac45f7a2a48ae5fb141dbe4 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Tue, 9 Jul 2019 11:22:52 +0200
Subject: [PATCH] #42 run baselines in ci

---
 torch_training/multi_agent_training.py | 3 ++-
 torch_training/training_navigation.py  | 4 +++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index 4524d78..216ddc6 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -7,6 +7,7 @@ import random
 import torch
 from dueling_double_dqn import Agent
 
+import torch_training
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -63,7 +64,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-with path("torch_training/Nets", "avoid_checkpoint30000.pth") as file_in:
+with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 demo = True
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 987c7d6..095f20b 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -7,6 +7,8 @@ import matplotlib.pyplot as plt
 import numpy as np
 import torch
 from dueling_double_dqn import Agent
+
+import torch_training
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -89,7 +91,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-with path("torch_training/Nets", "avoid_checkpoint30000.pth") as file_in:
+with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 
-- 
GitLab