Commit 0bd7cad8 authored by gmollard

Initial commit
from flatland.envs.rail_env import RailEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.envs.generators import random_rail_generator


class RailEnvRLLibWrapper(RailEnv, MultiAgentEnv):
    """Exposes a flatland RailEnv through RLLib's MultiAgentEnv dict interface."""

    def __init__(self,
                 width,
                 height,
                 rail_generator=random_rail_generator(),
                 number_of_agents=1,
                 obs_builder_object=TreeObsForRailEnv(max_depth=2)):
        super(RailEnvRLLibWrapper, self).__init__(width, height, rail_generator,
                                                  number_of_agents, obs_builder_object)

    def reset(self, regen_rail=True, replace_agents=True):
        # Track which agents have already been reported done so they can be
        # dropped from the dicts returned by later step() calls.
        self.agents_done = []
        return super(RailEnvRLLibWrapper, self).reset(regen_rail, replace_agents)

    def step(self, action_dict):
        obs, rewards, dones, infos = super(RailEnvRLLibWrapper, self).step(action_dict)

        d = dict()
        r = dict()
        o = dict()

        # Report observations and rewards only for agents that are still active;
        # the '__all__' flag stays in the dones dict, as RLLib expects.
        for agent, done in dones.items():
            if agent not in self.agents_done:
                if agent != '__all__':
                    o[agent] = obs[agent]
                    r[agent] = rewards[agent]
                d[agent] = dones[agent]

        # Remember agents that finished during this step.
        for agent, done in dones.items():
            if done and agent != '__all__':
                self.agents_done.append(agent)

        return o, r, d, infos

    def get_agent_handles(self):
        return super(RailEnvRLLibWrapper, self).get_agent_handles()
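The bookkeeping in step() follows RLLib's multi-agent convention: once an agent has been reported done it no longer appears in the per-agent observation and reward dicts, while the '__all__' flag stays in the dones dict. A minimal standalone sketch of that filtering, using hypothetical agent ids and values:

agents_done = ["agent_0"]  # agent_0 finished on an earlier step
dones = {"agent_0": True, "agent_1": False, "__all__": False}
rewards = {"agent_0": 0.0, "agent_1": -1.0}

# Keep only agents that have not yet been reported done.
filtered_rewards = {a: rew for a, rew in rewards.items() if a not in agents_done}
filtered_dones = {a, d} = {a: d for a, d in dones.items() if a not in agents_done}

print(filtered_rewards)  # {'agent_1': -1.0}
print(filtered_dones)    # {'agent_1': False, '__all__': False}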
File added
run_grid_search.name = "n_agents_results"
run_grid_search.num_iterations = 1002
run_grid_search.hidden_sizes = [32, 32]
run_grid_search.save_every = 100  # required by train(); 100 is an assumed placeholder value
run_grid_search.map_width = 15
run_grid_search.map_height = 15
run_grid_search.n_agents = {"grid_search": [1, 2, 3, 4]}
run_grid_search.horizon = 50
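These gin bindings supply the arguments of the @gin.configurable run_grid_search function in the script below, so the experiment can be reconfigured without touching the Python code. A minimal sketch of the mechanism, using a hypothetical function and gin.parse_config instead of a config file:

import gin

@gin.configurable
def launch(name, num_iterations):
    print(name, num_iterations)

gin.parse_config("""
launch.name = "demo"
launch.num_iterations = 3
""")
launch()  # arguments are filled in from the gin bindings: prints "demo 3"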
from baselines.RailEnvRLLibWrapper import RailEnvRLLibWrapper
import random
import gym
from flatland.envs.generators import complex_rail_generator
import ray.rllib.agents.ppo.ppo as ppo
from ray.rllib.agents.ppo.ppo import PPOAgent
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.tune.registry import register_env
from ray.rllib.models import ModelCatalog
from ray.tune.logger import pretty_print
from ray.rllib.models.preprocessors import Preprocessor
import ray
import numpy as np
import gin
from ray import tune
class MyPreprocessorClass(Preprocessor):
    def _init_shape(self, obs_space, options):
        return (105,)

    def transform(self, observation):
        return observation  # return the observation unchanged (identity preprocessor)


ModelCatalog.register_custom_preprocessor("my_prep", MyPreprocessorClass)
ray.init()
def train(config, reporter):
    print('Init Env')

    env_name = f"rail_env_{config['n_agents']}"  # To modify if different environment configs are explored.

    # Example: transition probabilities for generating a rail from a manual
    # specification, a map of tuples (cell_type, rotation). Unused here; kept
    # for reference.
    transition_probability = [0.5,  # empty cell - Case 0
                              1.0,  # Case 1 - straight
                              1.0,  # Case 2 - simple switch
                              0.3,  # Case 3 - diamond crossing
                              0.5,  # Case 4 - single slip
                              0.5,  # Case 5 - double slip
                              0.2,  # Case 6 - symmetrical
                              0.0]  # Case 7 - dead end

    # Generate a random rail
    env = RailEnvRLLibWrapper(width=config['map_width'], height=config['map_height'],
                              rail_generator=complex_rail_generator(nr_start_goal=config["n_agents"],
                                                                    nr_extra=20, min_dist=12),
                              number_of_agents=config["n_agents"])

    register_env(env_name, lambda _: env)

    obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
    act_space = gym.spaces.Discrete(4)

    # Dict with the different policies to train
    policy_graphs = {
        "ppo_policy": (PPOPolicyGraph, obs_space, act_space, {})
    }

    def policy_mapping_fn(agent_id):
        return "ppo_policy"

    agent_config = ppo.DEFAULT_CONFIG.copy()
    agent_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": "my_prep"}
    agent_config['multiagent'] = {"policy_graphs": policy_graphs,
                                  "policy_mapping_fn": policy_mapping_fn,
                                  "policies_to_train": list(policy_graphs.keys())}
    agent_config["horizon"] = config['horizon']

    ppo_trainer = PPOAgent(env=env_name, config=agent_config)

    for i in range(100000 + 2):
        print("== Iteration", i, "==")
        print("-- PPO --")
        print(pretty_print(ppo_trainer.train()))
        if i % config['save_every'] == 0:
            checkpoint = ppo_trainer.save()
            print("checkpoint saved at", checkpoint)
        reporter(num_iterations_trained=ppo_trainer._iteration)
@gin.configurable
def run_grid_search(name, num_iterations, n_agents, hidden_sizes, save_every,
                    map_width, map_height, horizon, local_dir):

    tune.run(
        train,
        name=name,
        stop={"num_iterations_trained": num_iterations},
        config={"n_agents": n_agents,
                "hidden_sizes": hidden_sizes,  # Array containing the sizes of the network layers
                "save_every": save_every,
                "map_width": map_width,
                "map_height": map_height,
                "local_dir": local_dir,
                "horizon": horizon  # Max number of time steps
                },
        resources_per_trial={
            "cpu": 11,
            "gpu": 0.5
        },
        local_dir=local_dir
    )


if __name__ == '__main__':
    gin.external_configurable(tune.grid_search)
    dir = 'grid_search_configs/n_agents_grid_search'
    gin.parse_config_file(dir + '/config.gin')
    run_grid_search(local_dir=dir)
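In config.gin, n_agents is bound to {"grid_search": [1, 2, 3, 4]}, which matches the dict tune.grid_search produces, so Tune expands the run into one trial per agent count, each reserving 11 CPUs and 0.5 GPUs. A quick check of that equivalence:

from ray import tune

spec = tune.grid_search([1, 2, 3, 4])
print(spec)  # {'grid_search': [1, 2, 3, 4]}, the same value written literally in config.gin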
train.py 0 → 100644
from flatland.envs import rail_env
from flatland.envs.rail_env import random_rail_generator
from baselines.RailEnvRLLibWrapper import RailEnvRLLibWrapper
from flatland.utils.rendertools import RenderTool
import random
import gym
import matplotlib.pyplot as plt
from flatland.envs.generators import complex_rail_generator
import ray.rllib.agents.ppo.ppo as ppo
import ray.rllib.agents.dqn.dqn as dqn
from ray.rllib.agents.ppo.ppo import PPOAgent
from ray.rllib.agents.dqn.dqn import DQNAgent
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
from ray.tune.registry import register_env
from ray.rllib.models import ModelCatalog
from ray.tune.logger import pretty_print
from ray.rllib.models.preprocessors import Preprocessor
import ray
import numpy as np
from ray.rllib.env.multi_agent_env import MultiAgentEnv
# RailEnv.__bases__ = (RailEnv.__bases__[0], MultiAgentEnv)
class MyPreprocessorClass(Preprocessor):
    def _init_shape(self, obs_space, options):
        return (105,)

    def transform(self, observation):
        return observation  # return the observation unchanged (identity preprocessor)


ModelCatalog.register_custom_preprocessor("my_prep", MyPreprocessorClass)
ray.init()
def train(config):
    print('Init Env')
    random.seed(1)
    np.random.seed(1)

    # Example: transition probabilities for generating a rail from a manual
    # specification, a map of tuples (cell_type, rotation). Unused here; kept
    # for reference.
    transition_probability = [0.5,  # empty cell - Case 0
                              1.0,  # Case 1 - straight
                              1.0,  # Case 2 - simple switch
                              0.3,  # Case 3 - diamond crossing
                              0.5,  # Case 4 - single slip
                              0.5,  # Case 5 - double slip
                              0.2,  # Case 6 - symmetrical
                              0.0]  # Case 7 - dead end

    # Generate a random rail
    env = RailEnvRLLibWrapper(width=15, height=15,
                              rail_generator=complex_rail_generator(nr_start_goal=1, nr_extra=20, min_dist=12),
                              number_of_agents=1)

    register_env("railenv", lambda _: env)

    # if config['render']:
    #     env_renderer = RenderTool(env, gl="QT")
    # plt.figure(figsize=(5, 5))

    obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
    act_space = gym.spaces.Discrete(4)

    # Dict with the different policies to train
    policy_graphs = {
        "ppo_policy": (PPOPolicyGraph, obs_space, act_space, {})
    }

    def policy_mapping_fn(agent_id):
        return "ppo_policy"

    agent_config = ppo.DEFAULT_CONFIG.copy()
    agent_config['model'] = {"fcnet_hiddens": [32, 32], "custom_preprocessor": "my_prep"}
    agent_config['multiagent'] = {"policy_graphs": policy_graphs,
                                  "policy_mapping_fn": policy_mapping_fn,
                                  "policies_to_train": list(policy_graphs.keys())}
    agent_config["horizon"] = 50

    # Name must match the environment registered above.
    ppo_trainer = PPOAgent(env="railenv", config=agent_config)

    for i in range(100000 + 2):
        print("== Iteration", i, "==")
        print("-- PPO --")
        print(pretty_print(ppo_trainer.train()))
        # if i % config['save_every'] == 0:
        #     checkpoint = ppo_trainer.save()
        #     print("checkpoint saved at", checkpoint)


train({})
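In both scripts every agent id is mapped to the single "ppo_policy" entry and that entry is the only one in policies_to_train, so all trains share one set of PPO weights (parameter sharing). A small sketch of what the mapping amounts to, with placeholder policy-spec entries:

policy_graphs = {"ppo_policy": ("PPOPolicyGraph", "obs_space", "act_space", {})}  # placeholder entries

def policy_mapping_fn(agent_id):
    return "ppo_policy"  # every agent id resolves to the same shared policy

# Every agent handle, e.g. 0..3, contributes experience to the same weights:
print({handle: policy_mapping_fn(handle) for handle in range(4)})
# -> {0: 'ppo_policy', 1: 'ppo_policy', 2: 'ppo_policy', 3: 'ppo_policy'}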