From d68164dba751c0696b48c0b6d9533a7918138118 Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume@iccluster091.iccluster.epfl.ch>
Date: Wed, 15 May 2019 17:14:27 +0200
Subject: [PATCH] some changes for more convenient trainer modification

---
 grid_search_train.py | 77 +++++++++++++++++---------------------------
 1 file changed, 30 insertions(+), 47 deletions(-)

diff --git a/grid_search_train.py b/grid_search_train.py
index 1d06bea..b30cb40 100644
--- a/grid_search_train.py
+++ b/grid_search_train.py
@@ -4,9 +4,11 @@ import gym
 
 from flatland.envs.generators import complex_rail_generator
 
-import ray.rllib.agents.ppo.ppo as ppo
-from ray.rllib.agents.ppo.ppo import PPOTrainer
-from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
+
+# Import PPO trainer: we can replace these imports by any other trainer from RLLib.
+from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
+from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
+from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
 from ray.rllib.models import ModelCatalog
 from ray.tune.logger import pretty_print
 
@@ -27,6 +29,7 @@ from ray import tune
 ModelCatalog.register_custom_preprocessor("my_prep", CustomPreprocessor)
 ray.init(object_store_memory=150000000000)
 
+
 def train(config, reporter):
     print('Init Env')
 
@@ -42,66 +45,47 @@ def train(config, reporter):
                               1,  # Case 1c (9) - simple turn left
                               1]  # Case 2b (10) - simple switch mirrored
 
-    # Example generate a random rail
-    """
-    env = RailEnv(width=10,
-                  height=10,
-                  rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
-                  number_of_agents=1)
-    """
+    # Example configuration to generate a random rail
     env_config = {"width":config['map_width'],
                   "height":config['map_height'],
                   "rail_generator":complex_rail_generator(nr_start_goal=config['n_agents'], min_dist=5, max_dist=99999, seed=0),
                   "number_of_agents":config['n_agents']}
 
-    """
-    env = RailEnv(width=20,
-                  height=20,
-                  rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
-                      ['../notebooks/temp.npy']),
-                  number_of_agents=3)
-
-    """
-
-
-
-    # Example generate a random rail
-    # env = RailEnvRLLibWrapper(width=config['map_width'], height=config['map_height'],
-    # rail_generator=complex_rail_generator(nr_start_goal=config["n_agents"], nr_extra=20, min_dist=12),
-    # number_of_agents=config["n_agents"])
+    # Observation space and action space definitions
     obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
     act_space = gym.spaces.Discrete(4)
 
     # Dict with the different policies to train
     policy_graphs = {
-        config['policy_folder_name'].format(**locals()): (PPOPolicyGraph, obs_space, act_space, {})
+        config['policy_folder_name'].format(**locals()): (PolicyGraph, obs_space, act_space, {})
     }
 
     def policy_mapping_fn(agent_id):
         return config['policy_folder_name'].format(**locals())
 
-    agent_config = ppo.DEFAULT_CONFIG.copy()
-    agent_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": "my_prep"}
-    agent_config['multiagent'] = {"policy_graphs": policy_graphs,
+
+    # Trainer configuration
+    trainer_config = DEFAULT_CONFIG.copy()
+    trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": "my_prep"}
+    trainer_config['multiagent'] = {"policy_graphs": policy_graphs,
                                     "policy_mapping_fn": policy_mapping_fn,
                                     "policies_to_train": list(policy_graphs.keys())}
-    agent_config["horizon"] = config['horizon']
-
-    agent_config["num_workers"] = 0
-    agent_config["num_cpus_per_worker"] = 10
-    agent_config["num_gpus"] = 0.5
-    agent_config["num_gpus_per_worker"] = 0.5
-    agent_config["num_cpus_for_driver"] = 2
-    agent_config["num_envs_per_worker"] = 10
-    agent_config["env_config"] = env_config
-    agent_config["batch_mode"] = "complete_episodes"
-    agent_config['simple_optimizer'] = False
+    trainer_config["horizon"] = config['horizon']
+
+    trainer_config["num_workers"] = 0
+    trainer_config["num_cpus_per_worker"] = 10
+    trainer_config["num_gpus"] = 0.5
+    trainer_config["num_gpus_per_worker"] = 0.5
+    trainer_config["num_cpus_for_driver"] = 2
+    trainer_config["num_envs_per_worker"] = 10
+    trainer_config["env_config"] = env_config
+    trainer_config["batch_mode"] = "complete_episodes"
+    trainer_config['simple_optimizer'] = False
 
     def logger_creator(conf):
         """Creates a Unified logger with a default logdir prefix
         containing the agent name and the env id
         """
-        print("FOLDER", config['policy_folder_name'])
         logdir = config['policy_folder_name'].format(**locals())
         logdir = tempfile.mkdtemp(
             prefix=logdir, dir=config['local_dir'])
@@ -109,19 +93,18 @@ def train(config, reporter):
 
     logger = logger_creator
 
-    ppo_trainer = PPOTrainer(env=RailEnvRLLibWrapper, config=agent_config, logger_creator=logger)
+    trainer = Trainer(env=RailEnvRLLibWrapper, config=trainer_config, logger_creator=logger)
 
     for i in range(100000 + 2):
         print("== Iteration", i, "==")
 
-        print("-- PPO --")
-        print(pretty_print(ppo_trainer.train()))
+        print(pretty_print(trainer.train()))
 
         if i % config['save_every'] == 0:
-            checkpoint = ppo_trainer.save()
+            checkpoint = trainer.save()
             print("checkpoint saved at", checkpoint)
 
-    reporter(num_iterations_trained=ppo_trainer._iteration)
+    reporter(num_iterations_trained=trainer._iteration)
 
 
 @gin.configurable
@@ -151,6 +134,6 @@ def run_grid_search(name, num_iterations, n_agents, hidden_sizes, save_every,
 
 if __name__ == '__main__':
     gin.external_configurable(tune.grid_search)
-    dir = '/mount/SDC/flatland/baselines/grid_search_configs/n_agents_grid_search'
+    dir = '/mount/SDC/flatland/baselines/grid_search_configs/n_agents_grid_search'  # To Modify
     gin.parse_config_file(dir + '/config.gin')
     run_grid_search(local_dir=dir)
--
GitLab