Commit 1f8dfa71 authored by u214892

#57 access resources for torch_training from resources; initial setup tox

parent 5ea7f884
Merge request !157: access resources through importlib resources
-from flatland.envs.rail_env import RailEnv
-import numpy as np
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from flatland.envs.observations import TreeObsForRailEnv
 from ray.rllib.utils.seed import seed as set_seed
 from flatland.envs.generators import complex_rail_generator, random_rail_generator
+import numpy as np
+from flatland.envs.rail_env import RailEnv

 class RailEnvRLLibWrapper(MultiAgentEnv):
@@ -20,20 +20,21 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         if config['rail_generator'] == "complex_rail_generator":
             self.rail_generator = complex_rail_generator(nr_start_goal=config['number_of_agents'], min_dist=5,
-                                                         nr_extra=config['nr_extra'], seed=config['seed'] * (1+vector_index))
+                                                         nr_extra=config['nr_extra'],
+                                                         seed=config['seed'] * (1 + vector_index))
elif config['rail_generator'] == "random_rail_generator":
self.rail_generator = random_rail_generator()
elif config['rail_generator'] == "load_env":
self.predefined_env = True
else:
raise(ValueError, f'Unknown rail generator: {config["rail_generator"]}')
raise (ValueError, f'Unknown rail generator: {config["rail_generator"]}')
set_seed(config['seed'] * (1+vector_index))
set_seed(config['seed'] * (1 + vector_index))
         self.env = RailEnv(width=config["width"], height=config["height"],
-                   number_of_agents=config["number_of_agents"],
-                   obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator,
-                   prediction_builder_object=config['predictor'])
+                           number_of_agents=config["number_of_agents"],
+                           obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator,
+                           prediction_builder_object=config['predictor'])

         if self.predefined_env:
             self.env.load(config['load_env_path'])
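For orientation, here is a hypothetical env_config exercising the constructor in this hunk; the keys are the ones the wrapper reads in the diff, while every value is made up. Note that multiplying the seed by (1 + vector_index) gives each vectorized copy of the environment its own random stream.

```python
from flatland.envs.observations import TreeObsForRailEnv

# Hypothetical config; keys match the wrapper above, values are illustrative.
env_config = {
    "rail_generator": "complex_rail_generator",
    "number_of_agents": 3,
    "nr_extra": 30,                 # extra rails for complex_rail_generator
    "seed": 1,                      # scaled by (1 + vector_index) per env copy
    "width": 20,
    "height": 20,
    "obs_builder": TreeObsForRailEnv(max_depth=2),
    "predictor": None,
    "load_env_path": None,          # only used when rail_generator == "load_env"
}
```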
@@ -190,8 +191,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
             elif collision_info[1] == 0:
                 # In this case, the other agent (agent 2) was on the same cell at t-1.
                 # There is a collision if agent 2 is, at t, on the cell where agent 1 was at t-1.
-                coord_agent_1_t_minus_1 = pred_pos[agent_handle, time_offset-1, 0] + \
-                                          1000 * pred_pos[agent_handle, time_offset, 1]
+                coord_agent_1_t_minus_1 = pred_pos[agent_handle, time_offset - 1, 0] + \
+                                          1000 * pred_pos[agent_handle, time_offset, 1]
                 coord_agent_2_t = coord_other_agents[collision_info[0], 1]
                 if coord_agent_1_t_minus_1 == coord_agent_2_t:
                     pred_obs[time_offset, collision_info[0] + 1 * (collision_info[0] >= agent_handle)] = 1
@@ -200,7 +201,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
                 # In this case, the other agent (agent 2) will be on the same cell at t+1.
                 # There is a collision if agent 2 is, at t, on the cell where agent 1 will be at t+1.
                 coord_agent_1_t_plus_1 = pred_pos[agent_handle, time_offset + 1, 0] + \
-                                        1000 * pred_pos[agent_handle, time_offset, 1]
+                                         1000 * pred_pos[agent_handle, time_offset, 1]
                 coord_agent_2_t = coord_other_agents[collision_info[0], 1]
                 if coord_agent_1_t_plus_1 == coord_agent_2_t:
                     pred_obs[time_offset, collision_info[0] + 1 * (collision_info[0] >= agent_handle)] = 1
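Both collision hunks compare positions through the same encoding: a (row, col) cell is folded into a single integer as row + 1000 * col, so two predicted positions can be checked with one equality test. A standalone sketch, assuming maps narrower than 1000 cells (which the factor implies):

```python
def encode_position(row, col):
    # Unambiguous as long as the column index stays below 1000,
    # matching the constant used in the hunks above.
    return row + 1000 * col

assert encode_position(3, 7) == 3 + 7000
assert encode_position(3, 7) != encode_position(7, 3)
```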
...
@@ -4,13 +4,13 @@ import gym
 import numpy as np
 import ray
 import ray.rllib.agents.ppo.ppo as ppo
-from RailEnvRLLibWrapper import RailEnvRLLibWrapper
 from ray.rllib.agents.ppo.ppo import PPOTrainer
 from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
 from ray.rllib.models import ModelCatalog
 from ray.tune.logger import pretty_print

 from RLLib_training.custom_preprocessors import CustomPreprocessor
+from RailEnvRLLibWrapper import RailEnvRLLibWrapper
 from flatland.envs.generators import complex_rail_generator

 ModelCatalog.register_custom_preprocessor("my_prep", CustomPreprocessor)
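Registering CustomPreprocessor under "my_prep" lets RLLib resolve it by name from a model config ("model": {"custom_preprocessor": "my_prep"}). A minimal sketch of the preprocessor interface RLLib expected in this era; the class body is illustrative, not the repo's actual implementation:

```python
from ray.rllib.models.preprocessors import Preprocessor


class SketchPreprocessor(Preprocessor):
    """Illustrative only: passes tree observations through unchanged."""

    def _init_shape(self, obs_space, options):
        # Shape of the observation after preprocessing.
        return (147,)

    def transform(self, observation):
        # Convert the raw env observation to the declared shape.
        return observation
```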
...
+import os
+import tempfile
+
 import gin
 import gym
-import gin
 from flatland.envs.generators import complex_rail_generator
 import ray
+from importlib_resources import path
 from ray import tune
 # Import PPO trainer: these imports can be replaced with any other RLLib trainer.
 from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
 from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
 from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
 from ray.rllib.models import ModelCatalog
 from ray.rllib.utils.seed import seed as set_seed
 from ray.tune.logger import pretty_print
 from baselines.RLLib_training.custom_preprocessors import CustomPreprocessor, ConvModelPreprocessor
 from baselines.RLLib_training.custom_models import ConvModelGlobalObs
 from flatland.envs.predictions import DummyPredictorForRailEnv

 gin.external_configurable(DummyPredictorForRailEnv)
-gin.external_configurable(DummyPredictorForRailEnv)
-import ray
-import numpy as np
-from ray.tune.logger import UnifiedLogger
-from ray.tune.logger import pretty_print
@@ -35,16 +21,13 @@ from ray.tune.logger import pretty_print
 from RailEnvRLLibWrapper import RailEnvRLLibWrapper
 from custom_models import ConvModelGlobalObs
 from custom_preprocessors import CustomPreprocessor, ConvModelPreprocessor
-from flatland.envs.generators import complex_rail_generator
-from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv, \
-    LocalObsForRailEnv, GlobalObsForRailEnvDirectionDependent
-import tempfile
-from ray import tune
-from ray.rllib.utils.seed import seed as set_seed
-from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv,\
-    LocalObsForRailEnv, GlobalObsForRailEnvDirectionDependent
+from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv, \
+    LocalObsForRailEnv, GlobalObsForRailEnvDirectionDependent

 gin.external_configurable(TreeObsForRailEnv)
 gin.external_configurable(GlobalObsForRailEnv)
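The gin.external_configurable calls register classes defined outside this file (the observation builders and the predictor) so that .gin experiment files can reference them. A hedged sketch of the usual wiring; the config lines and file name here are hypothetical:

```python
import gin

# experiment.gin might contain bindings such as:
#   run_experiment.obs_builder = @TreeObsForRailEnv()
#   TreeObsForRailEnv.max_depth = 2
gin.parse_config_file('experiment.gin')
```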
@@ -81,11 +64,13 @@ def train(config, reporter):
     # Observation space and action space definitions
     if isinstance(config["obs_builder"], TreeObsForRailEnv):
         if config['predictor'] is None:
-            obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)), ) * config['step_memory'])
+            obs_space = gym.spaces.Tuple(
+                (gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),) * config['step_memory'])
         else:
             obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),
-                                          gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
-                                          gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),) *config['step_memory'])
+                                          gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
+                                          gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),) * config[
+                                              'step_memory'])
         preprocessor = "tree_obs_prep"
     elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
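Expanded, the no-predictor branch above is just the 147-feature tree observation repeated once per remembered step; with a predictor, two extra Box spaces carrying per-agent collision information are appended before the repetition. A sketch with step_memory fixed to 2 (a hypothetical stand-in for config['step_memory']):

```python
import gym

step_memory = 2  # hypothetical stand-in for config['step_memory']
tree_obs = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,))
obs_space = gym.spaces.Tuple((tree_obs,) * step_memory)
assert len(obs_space.spaces) == step_memory
```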
@@ -120,7 +105,6 @@ def train(config, reporter):
     else:
         raise ValueError("Undefined observation space")
-
     act_space = gym.spaces.Discrete(5)

     # Dict with the different policies to train
@@ -190,7 +174,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
                    map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder,
                    entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae,
                    predictor, step_memory):
-
     tune.run(
         train,
         name=name,
...
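run_experiment hands everything to tune.run, which repeatedly invokes the train(config, reporter) function shown earlier. A minimal, self-contained sketch of that pattern, using the function-trainable API of this Ray generation; the trainable body and config values are placeholders, not the repo's real settings:

```python
from ray import tune


def toy_train(config, reporter):
    # Placeholder trainable: report a dummy metric each iteration.
    for step in range(config["num_iterations"]):
        reporter(timesteps_total=step, mean_accuracy=0.5)


tune.run(
    toy_train,
    name="example_experiment",
    config={"num_iterations": 2},
)
```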
@@ -5,7 +5,6 @@ from collections import deque
 import numpy as np
 import torch
 from dueling_double_dqn import Agent
-from flatland.envs.generators import complex_rail_generator
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
@@ -74,10 +73,11 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-#agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint1500.pth'))
+# agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint1500.pth'))
 demo = False


 def max_lt(seq, val):
     """
     Return greatest item in seq for which item < val applies.
...
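The diff cuts max_lt off right after its docstring. For completeness, a hedged reconstruction of a body matching that docstring; the repo's actual implementation may differ:

```python
def max_lt(seq, val):
    """Return greatest item in seq for which item < val applies."""
    largest = 0
    for item in seq:
        # Track the largest element strictly below the threshold.
        if item < val and item > largest:
            largest = item
    return largest  # 0 when no element qualifies
```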