From 1f8dfa7131ad5fc5a857ff080a1e4ce5e67fe159 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Wed, 12 Jun 2019 18:36:48 +0200
Subject: [PATCH] #57 access resources for torch_training from resources;
 initial setup tox

---
 RLLib_training/RailEnvRLLibWrapper.py | 25 ++++++++++----------
 RLLib_training/train.py               |  2 +-
 RLLib_training/train_experiment.py    | 33 +++++++--------------------
 torch_training/training_navigation.py |  4 ++--
 4 files changed, 24 insertions(+), 40 deletions(-)

diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py
index 57fe38e..4cba2f3 100644
--- a/RLLib_training/RailEnvRLLibWrapper.py
+++ b/RLLib_training/RailEnvRLLibWrapper.py
@@ -1,9 +1,9 @@
-from flatland.envs.rail_env import RailEnv
+import numpy as np
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
-from flatland.envs.observations import TreeObsForRailEnv
 from ray.rllib.utils.seed import seed as set_seed
+
 from flatland.envs.generators import complex_rail_generator, random_rail_generator
-import numpy as np
+from flatland.envs.rail_env import RailEnv
 
 
 class RailEnvRLLibWrapper(MultiAgentEnv):
@@ -20,20 +20,21 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
 
         if config['rail_generator'] == "complex_rail_generator":
             self.rail_generator = complex_rail_generator(nr_start_goal=config['number_of_agents'], min_dist=5,
-                                                          nr_extra=config['nr_extra'], seed=config['seed'] * (1+vector_index))
+                                                         nr_extra=config['nr_extra'],
+                                                         seed=config['seed'] * (1 + vector_index))
         elif config['rail_generator'] == "random_rail_generator":
             self.rail_generator = random_rail_generator()
         elif config['rail_generator'] == "load_env":
             self.predefined_env = True
 
         else:
-            raise(ValueError, f'Unknown rail generator: {config["rail_generator"]}')
+            raise (ValueError, f'Unknown rail generator: {config["rail_generator"]}')
 
-        set_seed(config['seed'] * (1+vector_index))
+        set_seed(config['seed'] * (1 + vector_index))
         self.env = RailEnv(width=config["width"], height=config["height"],
-                number_of_agents=config["number_of_agents"],
-                obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator,
-                prediction_builder_object=config['predictor'])
+                           number_of_agents=config["number_of_agents"],
+                           obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator,
+                           prediction_builder_object=config['predictor'])
 
         if self.predefined_env:
             self.env.load(config['load_env_path'])
@@ -190,8 +191,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
                 elif collision_info[1] == 0:
                     # In this case, the other agent (agent 2) was on the same cell at t-1
                     # There is a collision if agent 2 is at t, on the cell where was agent 1 at t-1
-                    coord_agent_1_t_minus_1 = pred_pos[agent_handle, time_offset-1, 0] + \
-                                          1000 * pred_pos[agent_handle, time_offset, 1]
+                    coord_agent_1_t_minus_1 = pred_pos[agent_handle, time_offset - 1, 0] + \
+                                              1000 * pred_pos[agent_handle, time_offset, 1]
                     coord_agent_2_t = coord_other_agents[collision_info[0], 1]
                     if coord_agent_1_t_minus_1 == coord_agent_2_t:
                         pred_obs[time_offset, collision_info[0] + 1 * (collision_info[0] >= agent_handle)] = 1
@@ -200,7 +201,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
                     # In this case, the other agent (agent 2) will be on the same cell at t+1
                     # There is a collision if agent 2 is at t, on the cell where will be agent 1 at t+1
                     coord_agent_1_t_plus_1 = pred_pos[agent_handle, time_offset + 1, 0] + \
-                                              1000 * pred_pos[agent_handle, time_offset, 1]
+                                             1000 * pred_pos[agent_handle, time_offset, 1]
                     coord_agent_2_t = coord_other_agents[collision_info[0], 1]
                     if coord_agent_1_t_plus_1 == coord_agent_2_t:
                         pred_obs[time_offset, collision_info[0] + 1 * (collision_info[0] >= agent_handle)] = 1
diff --git a/RLLib_training/train.py b/RLLib_training/train.py
index 5d07c8b..1546205 100644
--- a/RLLib_training/train.py
+++ b/RLLib_training/train.py
@@ -4,13 +4,13 @@ import gym
 import numpy as np
 import ray
 import ray.rllib.agents.ppo.ppo as ppo
+from RailEnvRLLibWrapper import RailEnvRLLibWrapper
 from ray.rllib.agents.ppo.ppo import PPOTrainer
 from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
 from ray.rllib.models import ModelCatalog
 from ray.tune.logger import pretty_print
 
 from RLLib_training.custom_preprocessors import CustomPreprocessor
-from RailEnvRLLibWrapper import RailEnvRLLibWrapper
 from flatland.envs.generators import complex_rail_generator
 
 ModelCatalog.register_custom_preprocessor("my_prep", CustomPreprocessor)
diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py
index d7d0b4a..2853071 100644
--- a/RLLib_training/train_experiment.py
+++ b/RLLib_training/train_experiment.py
@@ -1,33 +1,19 @@
 import os
-import tempfile
 
 import gin
 import gym
-
-import gin
-
-from flatland.envs.generators import complex_rail_generator
-
-import ray
 from importlib_resources import path
-from ray import tune
 # Import PPO trainer: we can replace these imports by any other trainer from RLLib.
 from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
 from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
 from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
 from ray.rllib.models import ModelCatalog
-from ray.rllib.utils.seed import seed as set_seed
-from ray.tune.logger import pretty_print
-from baselines.RLLib_training.custom_preprocessors import CustomPreprocessor, ConvModelPreprocessor
-
-from baselines.RLLib_training.custom_models import ConvModelGlobalObs
 
 from flatland.envs.predictions import DummyPredictorForRailEnv
-gin.external_configurable(DummyPredictorForRailEnv)
 
+gin.external_configurable(DummyPredictorForRailEnv)
 
 import ray
-import numpy as np
 
 from ray.tune.logger import UnifiedLogger
 from ray.tune.logger import pretty_print
@@ -35,16 +21,13 @@ from ray.tune.logger import pretty_print
 from RailEnvRLLibWrapper import RailEnvRLLibWrapper
 from custom_models import ConvModelGlobalObs
 from custom_preprocessors import CustomPreprocessor, ConvModelPreprocessor
-from flatland.envs.generators import complex_rail_generator
-from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv, \
-    LocalObsForRailEnv, GlobalObsForRailEnvDirectionDependent
 import tempfile
 
 from ray import tune
 
 from ray.rllib.utils.seed import seed as set_seed
-from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv,\
-                                       LocalObsForRailEnv, GlobalObsForRailEnvDirectionDependent
+from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv, \
+    LocalObsForRailEnv, GlobalObsForRailEnvDirectionDependent
 
 gin.external_configurable(TreeObsForRailEnv)
 gin.external_configurable(GlobalObsForRailEnv)
@@ -81,11 +64,13 @@ def train(config, reporter):
     # Observation space and action space definitions
     if isinstance(config["obs_builder"], TreeObsForRailEnv):
         if config['predictor'] is None:
-            obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)), ) * config['step_memory'])
+            obs_space = gym.spaces.Tuple(
+                (gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),) * config['step_memory'])
         else:
             obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),
-                                        gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
-                                        gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),) *config['step_memory'])
+                                          gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
+                                          gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),) * config[
+                                             'step_memory'])
         preprocessor = "tree_obs_prep"
 
     elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
@@ -120,7 +105,6 @@ def train(config, reporter):
     else:
         raise ValueError("Undefined observation space")
 
-
     act_space = gym.spaces.Discrete(5)
 
     # Dict with the different policies to train
@@ -190,7 +174,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
                    map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder,
                    entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae,
                    predictor, step_memory):
-
     tune.run(
         train,
         name=name,
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index e673941..1fbe149 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -5,7 +5,6 @@ from collections import deque
 import numpy as np
 import torch
 
-from dueling_double_dqn import Agent
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
@@ -74,10 +73,11 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-#agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint1500.pth'))
+# agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint1500.pth'))
 
 demo = False
 
+
 def max_lt(seq, val):
     """
     Return greatest item in seq for which item < val applies.
-- 
GitLab