From 60b730d019313dade48abc04a784bff3fb3e37ef Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Tue, 9 Jul 2019 11:50:19 +0200
Subject: [PATCH] #42 run baselines in ci

---
 requirements_torch_training.txt        |  2 ++
 torch_training/multi_agent_training.py |  2 +-
 torch_training/training_navigation.py  | 21 ++++++++-------------
 3 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/requirements_torch_training.txt b/requirements_torch_training.txt
index 03f6383..d8c4a46 100644
--- a/requirements_torch_training.txt
+++ b/requirements_torch_training.txt
@@ -1,2 +1,4 @@
 git+https://gitlab.aicrowd.com/flatland/flatland.git@master
+importlib-metadata>=0.17
+importlib_resources>=1.0.2
 torch>=1.1.0
\ No newline at end of file
diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index ba9f144..0037f95 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -1,11 +1,11 @@
 from collections import deque
-from sys import path
 
 import matplotlib.pyplot as plt
 import numpy as np
 import random
 import torch
 from dueling_double_dqn import Agent
+from importlib_resources import path
 
 import torch_training.Nets
 from flatland.envs.generators import complex_rail_generator
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 1857b67..c52a891 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -1,12 +1,11 @@
-from sys import path
-
-import random
 from collections import deque
 
 import matplotlib.pyplot as plt
 import numpy as np
+import random
 import torch
 from dueling_double_dqn import Agent
+from importlib_resources import path
 
 import torch_training.Nets
 from flatland.envs.generators import complex_rail_generator
@@ -14,7 +13,6 @@ from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
-
 from utils.observation_utils import norm_obs_clip, split_tree
 
 random.seed(1)
@@ -70,7 +68,7 @@ file_load = False
 
 """
 observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
-env_renderer = RenderTool(env, gl="PILSVG",)
+env_renderer = RenderTool(env, gl="PILSVG", )
 handle = env.get_agent_handles()
 features_per_node = 9
 state_size = features_per_node * 85 * 2
@@ -94,11 +92,9 @@ agent = Agent(state_size, action_size, "FC", 0)
 with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
-
 demo = True
 record_images = False
 
-
 for trials in range(1, n_trials + 1):
 
     if trials % 50 == 0 and not demo:
@@ -136,7 +132,7 @@ for trials in range(1, n_trials + 1):
         agent_data = np.clip(agent_data, -1, 1)
         obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
         agent_data = env.agents[a]
-        speed = 1 #np.random.randint(1,5)
+        speed = 1  # np.random.randint(1,5)
         agent_data.speed_data['speed'] = 1. / speed
 
     for i in range(2):
@@ -145,7 +141,6 @@ for trials in range(1, n_trials + 1):
     for a in range(env.get_num_agents()):
         agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
 
-
     score = 0
     env_done = 0
     # Run episode
@@ -206,10 +201,10 @@ for trials in range(1, n_trials + 1):
     print(
         '\rTraining {} Agents on ({},{}).\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
             env.get_num_agents(), x_dim, y_dim,
-              trials,
-              np.mean(scores_window),
-              100 * np.mean(done_window),
-              eps, action_prob / np.sum(action_prob)), end=" ")
+            trials,
+            np.mean(scores_window),
+            100 * np.mean(done_window),
+            eps, action_prob / np.sum(action_prob)), end=" ")
 
     if trials % 100 == 0:
         print(
-- 
GitLab