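# NOTE: the following header was captured from the web code viewer, not from the script itself.
# Viewer message: "An error occurred while loading the file. Please try again."
# Forked from Flatland / baselines; 305 commits behind the upstream repository.
# File: grid_search_train.py (5.02 KiB)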
from baselines.RailEnvRLLibWrapper import RailEnvRLLibWrapper
import gym
from flatland.envs.generators import complex_rail_generator
import ray.rllib.agents.ppo.ppo as ppo
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.models import ModelCatalog
from ray.tune.logger import pretty_print
from ray.rllib.models.preprocessors import Preprocessor
import ray
import numpy as np
import gin
from ray import tune
class MyPreprocessorClass(Preprocessor):
    """Pass-through preprocessor that only declares the observation shape.

    Observations are handed on unchanged; the declared output shape is the
    fixed tuple ``(105,)``.
    """

    def _init_shape(self, obs_space, options):
        """Return the shape of the preprocessed observation, ``(105,)``.

        ``obs_space`` and ``options`` are accepted but not consulted.
        """
        output_shape = (105,)
        return output_shape

    def transform(self, observation):
        """Return the preprocessed observation, which is ``observation`` itself."""
        return observation
# Register the pass-through preprocessor under the name "my_prep" so that a
# model config can select it through its "custom_preprocessor" entry.
ModelCatalog.register_custom_preprocessor("my_prep", MyPreprocessorClass)
# Start Ray with its default arguments. This runs at import time, i.e. as a
# module-level side effect, before any trial is launched.
ray.init()
def train(config, reporter):
    """Train one shared PPO policy on a Flatland rail environment.

    The function builds the environment settings and the PPO trainer
    configuration from ``config``, then runs training iterations, saving a
    checkpoint periodically and reporting progress after every iteration.

    Args:
        config: Mapping of trial parameters. The keys read here are
            ``map_width``, ``map_height``, ``n_agents``, ``hidden_sizes``,
            ``horizon`` and ``save_every``.
        reporter: Callable invoked once per iteration as
            ``reporter(num_iterations_trained=<trainer iteration counter>)``.

    Raises:
        KeyError: If one of the keys listed above is missing from ``config``.
    """
    print('Init Env')

    # Settings handed to RailEnvRLLibWrapper through the trainer's
    # "env_config" entry (see below).
    env_config = {
        "width": config['map_width'],
        "height": config['map_height'],
        "rail_generator": complex_rail_generator(nr_start_goal=50, min_dist=5,
                                                 max_dist=99999, seed=0),
        "number_of_agents": config['n_agents'],
    }

    # Flat observation vector of length 105 and 4 discrete actions.
    # The 105 must stay in sync with the shape declared by the "my_prep"
    # preprocessor selected in the model config below.
    obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
    act_space = gym.spaces.Discrete(4)

    # Dict with the different policies to train
    policy_graphs = {
        "ppo_policy": (PPOPolicyGraph, obs_space, act_space, {})
    }

    def policy_mapping_fn(agent_id):
        # Every agent, whatever its id, is driven by the single shared policy.
        return "ppo_policy"

    # Start from the PPO defaults and override only what this script needs.
    agent_config = ppo.DEFAULT_CONFIG.copy()
    agent_config['model'] = {"fcnet_hiddens": config['hidden_sizes'],
                             "custom_preprocessor": "my_prep"}
    agent_config['multiagent'] = {"policy_graphs": policy_graphs,
                                  "policy_mapping_fn": policy_mapping_fn,
                                  "policies_to_train": list(policy_graphs.keys())}
    agent_config["horizon"] = config['horizon']
    # Resource-related keys (num_workers, num_cpus_per_worker, num_gpus,
    # num_gpus_per_worker, num_cpus_for_driver, num_envs_per_worker) are not
    # overridden and therefore keep the values from ppo.DEFAULT_CONFIG.
    agent_config["env_config"] = env_config
    agent_config["batch_mode"] = "complete_episodes"

    ppo_trainer = PPOTrainer(env=RailEnvRLLibWrapper, config=agent_config)

    # Hard upper bound on the number of iterations; presumably the caller's
    # stop criterion on "num_iterations_trained" ends the trial much earlier.
    for i in range(100000 + 2):
        print("== Iteration", i, "==")
        print("-- PPO --")
        print(pretty_print(ppo_trainer.train()))

        # Checkpoint on every multiple of save_every (including iteration 0).
        if i % config['save_every'] == 0:
            checkpoint = ppo_trainer.save()
            print("checkpoint saved at", checkpoint)

        # Report after each iteration so the caller can track progress.
        reporter(num_iterations_trained=ppo_trainer._iteration)
@gin.configurable
def run_grid_search(name, num_iterations, n_agents, hidden_sizes, save_every,
                    map_width, map_height, horizon, local_dir):
    """Run ``train`` as a Ray Tune experiment.

    Args:
        name: Experiment name passed to ``tune.run``.
        num_iterations: Value of ``num_iterations_trained`` at which a trial
            is stopped.
        n_agents: Forwarded to ``train`` as ``config["n_agents"]``.
        hidden_sizes: Array containing the sizes of the network layers.
        save_every: Forwarded to ``train`` as ``config["save_every"]``.
        map_width: Forwarded to ``train`` as ``config["map_width"]``.
        map_height: Forwarded to ``train`` as ``config["map_height"]``.
        horizon: Max number of time steps.
        local_dir: Passed to ``tune.run`` as ``local_dir`` and also stored in
            the trial config.
    """
    # Per-trial parameters handed to ``train`` as its ``config`` argument.
    trial_config = {
        "n_agents": n_agents,
        "hidden_sizes": hidden_sizes,
        "save_every": save_every,
        "map_width": map_width,
        "map_height": map_height,
        "local_dir": local_dir,
        "horizon": horizon,
    }
    # A trial ends once ``train`` reports this many iterations.
    stop_criteria = {"num_iterations_trained": num_iterations}
    # Fixed resource request for every trial.
    trial_resources = {"cpu": 11, "gpu": 0.5}

    tune.run(
        train,
        name=name,
        stop=stop_criteria,
        config=trial_config,
        resources_per_trial=trial_resources,
        local_dir=local_dir,
    )
if __name__ == '__main__':
    # Make tune.grid_search available to gin before the config file is parsed.
    gin.external_configurable(tune.grid_search)
    # Named config_dir (not `dir`) so the builtin dir() is not shadowed.
    config_dir = 'baselines/grid_search_configs/n_agents_grid_search'
    gin.parse_config_file(config_dir + '/config.gin')
    # Only local_dir is passed explicitly; the remaining arguments of
    # run_grid_search are presumably bound by the parsed gin config.
    run_grid_search(local_dir=config_dir)