From 054bc93a4336d831a5ce96e2519d82d56bb53fff Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Tue, 9 Jul 2019 11:18:21 +0200
Subject: [PATCH] #42 run baselines in ci

---
 torch_training/multi_agent_training.py | 8 +++++---
 torch_training/training_navigation.py  | 6 +++++-
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index eb11fb1..4524d78 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -1,16 +1,17 @@
-import random
 from collections import deque
+from sys import path
 
 import matplotlib.pyplot as plt
 import numpy as np
+import random
 import torch
 from dueling_double_dqn import Agent
+
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
-
 from utils.observation_utils import norm_obs_clip, split_tree
 
 random.seed(1)
@@ -62,7 +63,8 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint30000.pth'))
+with path("torch_training/Nets", "avoid_checkpoint30000.pth") as file_in:
+    agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 demo = True
 record_images = False
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 3007ce5..987c7d6 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -1,3 +1,5 @@
+from sys import path
+
 import random
 from collections import deque
 
@@ -87,7 +89,9 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint30000.pth'))
+with path("torch_training/Nets", "avoid_checkpoint30000.pth") as file_in:
+    agent.qnetwork_local.load_state_dict(torch.load(file_in))
+
 
 demo = True
 record_images = False
-- 
GitLab