From 1f8dfa7131ad5fc5a857ff080a1e4ce5e67fe159 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Wed, 12 Jun 2019 18:36:48 +0200
Subject: [PATCH] #57 access resources for torch_training from resources;
 initial setup tox

---
 RLLib_training/RailEnvRLLibWrapper.py | 25 ++++++++++----------
 RLLib_training/train.py               |  2 +-
 RLLib_training/train_experiment.py    | 33 +++++++--------------------
 torch_training/training_navigation.py |  4 ++--
 4 files changed, 24 insertions(+), 40 deletions(-)

diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py
index 57fe38e..4cba2f3 100644
--- a/RLLib_training/RailEnvRLLibWrapper.py
+++ b/RLLib_training/RailEnvRLLibWrapper.py
@@ -1,9 +1,9 @@
-from flatland.envs.rail_env import RailEnv
+import numpy as np
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
-from flatland.envs.observations import TreeObsForRailEnv
 from ray.rllib.utils.seed import seed as set_seed
+
 from flatland.envs.generators import complex_rail_generator, random_rail_generator
-import numpy as np
+from flatland.envs.rail_env import RailEnv
 
 
 class RailEnvRLLibWrapper(MultiAgentEnv):
@@ -20,20 +20,21 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
 
         if config['rail_generator'] == "complex_rail_generator":
             self.rail_generator = complex_rail_generator(nr_start_goal=config['number_of_agents'], min_dist=5,
-                                                         nr_extra=config['nr_extra'], seed=config['seed'] * (1+vector_index))
+                                                         nr_extra=config['nr_extra'],
+                                                         seed=config['seed'] * (1 + vector_index))
         elif config['rail_generator'] == "random_rail_generator":
             self.rail_generator = random_rail_generator()
         elif config['rail_generator'] == "load_env":
             self.predefined_env = True
         else:
-            raise(ValueError, f'Unknown rail generator: {config["rail_generator"]}')
+            raise (ValueError, f'Unknown rail generator: {config["rail_generator"]}')
 
-        set_seed(config['seed'] * (1+vector_index))
+        set_seed(config['seed'] * (1 + vector_index))
         self.env = RailEnv(width=config["width"], height=config["height"],
-                            number_of_agents=config["number_of_agents"],
-                            obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator,
-                            prediction_builder_object=config['predictor'])
+                           number_of_agents=config["number_of_agents"],
+                           obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator,
+                           prediction_builder_object=config['predictor'])
 
         if self.predefined_env:
             self.env.load(config['load_env_path'])
 
@@ -190,8 +191,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
                 elif collision_info[1] == 0:
                     # In this case, the other agent (agent 2) was on the same cell at t-1
                     # There is a collision if agent 2 is at t, on the cell where was agent 1 at t-1
-                    coord_agent_1_t_minus_1 = pred_pos[agent_handle, time_offset-1, 0] + \
-                        1000 * pred_pos[agent_handle, time_offset, 1]
+                    coord_agent_1_t_minus_1 = pred_pos[agent_handle, time_offset - 1, 0] + \
+                                              1000 * pred_pos[agent_handle, time_offset, 1]
                     coord_agent_2_t = coord_other_agents[collision_info[0], 1]
                     if coord_agent_1_t_minus_1 == coord_agent_2_t:
                         pred_obs[time_offset, collision_info[0] + 1 * (collision_info[0] >= agent_handle)] = 1
@@ -200,7 +201,7 @@
                     # In this case, the other agent (agent 2) will be on the same cell at t+1
                     # There is a collision if agent 2 is at t, on the cell where will be agent 1 at t+1
                     coord_agent_1_t_plus_1 = pred_pos[agent_handle, time_offset + 1, 0] + \
-                        1000 * pred_pos[agent_handle, time_offset, 1]
+                                             1000 * pred_pos[agent_handle, time_offset, 1]
                     coord_agent_2_t = coord_other_agents[collision_info[0], 1]
                     if coord_agent_1_t_plus_1 == coord_agent_2_t:
                         pred_obs[time_offset, collision_info[0] + 1 * (collision_info[0] >= agent_handle)] = 1
diff --git a/RLLib_training/train.py b/RLLib_training/train.py
index 5d07c8b..1546205 100644
--- a/RLLib_training/train.py
+++ b/RLLib_training/train.py
@@ -4,13 +4,13 @@ import gym
 import numpy as np
 import ray
 import ray.rllib.agents.ppo.ppo as ppo
+from RailEnvRLLibWrapper import RailEnvRLLibWrapper
 from ray.rllib.agents.ppo.ppo import PPOTrainer
 from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
 from ray.rllib.models import ModelCatalog
 from ray.tune.logger import pretty_print
 
 from RLLib_training.custom_preprocessors import CustomPreprocessor
-from RailEnvRLLibWrapper import RailEnvRLLibWrapper
 from flatland.envs.generators import complex_rail_generator
 
 ModelCatalog.register_custom_preprocessor("my_prep", CustomPreprocessor)
diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py
index d7d0b4a..2853071 100644
--- a/RLLib_training/train_experiment.py
+++ b/RLLib_training/train_experiment.py
@@ -1,33 +1,19 @@
 import os
-import tempfile
 
 import gin
 import gym
-
-import gin
-
-from flatland.envs.generators import complex_rail_generator
-
-import ray
 from importlib_resources import path
-from ray import tune
 # Import PPO trainer: we can replace these imports by any other trainer from RLLib.
 from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
 from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
 from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
 from ray.rllib.models import ModelCatalog
-from ray.rllib.utils.seed import seed as set_seed
-from ray.tune.logger import pretty_print
-from baselines.RLLib_training.custom_preprocessors import CustomPreprocessor, ConvModelPreprocessor
-
-from baselines.RLLib_training.custom_models import ConvModelGlobalObs
 
 from flatland.envs.predictions import DummyPredictorForRailEnv
 
-gin.external_configurable(DummyPredictorForRailEnv)
+gin.external_configurable(DummyPredictorForRailEnv)
 
 import ray
-import numpy as np
 from ray.tune.logger import UnifiedLogger
 from ray.tune.logger import pretty_print
 
@@ -35,16 +21,13 @@ from ray.tune.logger import pretty_print
 from RailEnvRLLibWrapper import RailEnvRLLibWrapper
 from custom_models import ConvModelGlobalObs
 from custom_preprocessors import CustomPreprocessor, ConvModelPreprocessor
-from flatland.envs.generators import complex_rail_generator
-from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv, \
-    LocalObsForRailEnv, GlobalObsForRailEnvDirectionDependent
 
 import tempfile
 
 from ray import tune
 from ray.rllib.utils.seed import seed as set_seed
-from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv,\
-    LocalObsForRailEnv, GlobalObsForRailEnvDirectionDependent
+from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv, \
+    LocalObsForRailEnv, GlobalObsForRailEnvDirectionDependent
 
 gin.external_configurable(TreeObsForRailEnv)
 gin.external_configurable(GlobalObsForRailEnv)
@@ -81,11 +64,13 @@ def train(config, reporter):
     # Observation space and action space definitions
     if isinstance(config["obs_builder"], TreeObsForRailEnv):
         if config['predictor'] is None:
-            obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)), ) * config['step_memory'])
+            obs_space = gym.spaces.Tuple(
+                (gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),) * config['step_memory'])
         else:
             obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),
-                                          gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
-                                          gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),) *config['step_memory'])
+                                          gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
+                                          gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),) * config[
+                                             'step_memory'])
         preprocessor = "tree_obs_prep"
 
     elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
@@ -120,7 +105,6 @@ def train(config, reporter):
     else:
         raise ValueError("Undefined observation space")
 
-
     act_space = gym.spaces.Discrete(5)
 
     # Dict with the different policies to train
@@ -190,7 +174,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
                    map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder,
                    entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae,
                    predictor, step_memory):
-
     tune.run(
         train,
         name=name,
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index e673941..1fbe149 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -5,7 +5,6 @@ from collections import deque
 
 import numpy as np
 import torch
-from dueling_double_dqn import Agent
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
@@ -74,10 +73,11 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-#agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint1500.pth'))
+# agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint1500.pth'))
 
 demo = False
 
+
 def max_lt(seq, val):
     """
     Return greatest item in seq for which item < val applies.
--
GitLab
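Note on the collision check touched by the last two RailEnvRLLibWrapper.py hunks: a predicted cell is collapsed into a single integer as first_coordinate + 1000 * second_coordinate, so two agents are flagged as colliding exactly when their encoded positions compare equal (unambiguous as long as the first coordinate stays below 1000, which holds for these map sizes). The snippet below is an illustrative sketch of that encoding only; the helper name encode_cell and the toy pred_pos array are invented for the example, while the 1000-factor formula itself comes from the patch.

import numpy as np


def encode_cell(pred_pos, agent, t):
    # Collapse the predicted (coord0, coord1) of `agent` at step `t` into one integer,
    # mirroring pred_pos[agent, t, 0] + 1000 * pred_pos[agent, t, 1] in the patch above.
    return pred_pos[agent, t, 0] + 1000 * pred_pos[agent, t, 1]


# Hypothetical predictions: 2 agents, 3 time steps, 2 coordinates per cell.
pred_pos = np.zeros((2, 3, 2), dtype=int)
pred_pos[0, 1] = (4, 7)  # agent 0 predicted on cell (4, 7) at t=1
pred_pos[1, 1] = (4, 7)  # agent 1 predicted on the same cell at t=1
assert encode_cell(pred_pos, 0, 1) == encode_cell(pred_pos, 1, 1)  # equal codes -> potential collision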