Commit d68164db authored by Guillaume Mollard

some changes for more convenient trainer modification

parent 4306c599
@@ -4,9 +4,11 @@ import gym
 from flatland.envs.generators import complex_rail_generator
-import ray.rllib.agents.ppo.ppo as ppo
-from ray.rllib.agents.ppo.ppo import PPOTrainer
-from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
+# Import PPO trainer: we can replace these imports by any other trainer from RLlib.
+from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
+from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
+from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
 from ray.rllib.models import ModelCatalog
 from ray.tune.logger import pretty_print
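Note on the aliasing above: since the rest of the script refers only to Trainer, PolicyGraph, and DEFAULT_CONFIG, switching algorithms now only requires changing the import lines. A minimal sketch, assuming the same Ray version as this script (one that still pairs trainers with *_policy_graph modules), of swapping PPO for DQN:

    # Hypothetical swap, not part of this commit: alias another RLlib trainer
    # and its policy graph under the names the rest of the script expects.
    from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG
    from ray.rllib.agents.dqn.dqn import DQNTrainer as Trainer
    from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph as PolicyGraph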
@@ -27,6 +29,7 @@ from ray import tune
 ModelCatalog.register_custom_preprocessor("my_prep", CustomPreprocessor)
 ray.init(object_store_memory=150000000000)
+
 def train(config, reporter):
     print('Init Env')
@@ -42,66 +45,47 @@ def train(config, reporter):
                               1,  # Case 1c (9) - simple turn left
                               1]  # Case 2b (10) - simple switch mirrored
 
-    # Example generate a random rail
-    """
-    env = RailEnv(width=10,
-                  height=10,
-                  rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
-                  number_of_agents=1)
-    """
+    # Example configuration to generate a random rail
     env_config = {"width": config['map_width'],
                   "height": config['map_height'],
                   "rail_generator": complex_rail_generator(nr_start_goal=config['n_agents'], min_dist=5, max_dist=99999, seed=0),
                   "number_of_agents": config['n_agents']}
 
-    """
-    env = RailEnv(width=20,
-                  height=20,
-                  rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
-                      ['../notebooks/temp.npy']),
-                  number_of_agents=3)
-    """
-    # Example generate a random rail
-    # env = RailEnvRLLibWrapper(width=config['map_width'], height=config['map_height'],
-    #                           rail_generator=complex_rail_generator(nr_start_goal=config["n_agents"], nr_extra=20, min_dist=12),
-    #                           number_of_agents=config["n_agents"])
+    # Observation space and action space definitions
     obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
     act_space = gym.spaces.Discrete(4)
     # Dict with the different policies to train
     policy_graphs = {
-        config['policy_folder_name'].format(**locals()): (PPOPolicyGraph, obs_space, act_space, {})
+        config['policy_folder_name'].format(**locals()): (PolicyGraph, obs_space, act_space, {})
     }
 
     def policy_mapping_fn(agent_id):
         return config['policy_folder_name'].format(**locals())
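For context: in this generation of the RLlib multiagent API, policy_graphs maps a policy id string to a (policy graph class, observation space, action space, extra config) tuple, and policy_mapping_fn routes each agent id to one of those policy ids; here every agent shares the single PPO policy. A hedged sketch of a two-policy variant (the policy ids and the integer agent-id assumption are illustrative only):

    # Illustrative only: train two separate policies and split agents by parity.
    policy_graphs = {
        "ppo_policy_even": (PolicyGraph, obs_space, act_space, {}),
        "ppo_policy_odd": (PolicyGraph, obs_space, act_space, {}),
    }

    def policy_mapping_fn(agent_id):
        # Assumes integer agent ids, as the Flatland wrapper appears to use here.
        return "ppo_policy_even" if agent_id % 2 == 0 else "ppo_policy_odd"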
-    agent_config = ppo.DEFAULT_CONFIG.copy()
-    agent_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": "my_prep"}
-    agent_config['multiagent'] = {"policy_graphs": policy_graphs,
+    # Trainer configuration
+    trainer_config = DEFAULT_CONFIG.copy()
+    trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": "my_prep"}
+    trainer_config['multiagent'] = {"policy_graphs": policy_graphs,
                                     "policy_mapping_fn": policy_mapping_fn,
                                     "policies_to_train": list(policy_graphs.keys())}
-    agent_config["horizon"] = config['horizon']
-    agent_config["num_workers"] = 0
-    agent_config["num_cpus_per_worker"] = 10
-    agent_config["num_gpus"] = 0.5
-    agent_config["num_gpus_per_worker"] = 0.5
-    agent_config["num_cpus_for_driver"] = 2
-    agent_config["num_envs_per_worker"] = 10
-    agent_config["env_config"] = env_config
-    agent_config["batch_mode"] = "complete_episodes"
-    agent_config['simple_optimizer'] = False
+    trainer_config["horizon"] = config['horizon']
+    trainer_config["num_workers"] = 0
+    trainer_config["num_cpus_per_worker"] = 10
+    trainer_config["num_gpus"] = 0.5
+    trainer_config["num_gpus_per_worker"] = 0.5
+    trainer_config["num_cpus_for_driver"] = 2
+    trainer_config["num_envs_per_worker"] = 10
+    trainer_config["env_config"] = env_config
+    trainer_config["batch_mode"] = "complete_episodes"
+    trainer_config['simple_optimizer'] = False
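A note on the env_config key (standard RLlib plumbing rather than something this commit adds): RLlib instantiates the env class handed to the trainer with this dict, so RailEnvRLLibWrapper presumably unpacks it in its constructor. A rough sketch of that contract, using a hypothetical stand-in wrapper:

    # Hypothetical stand-in for RailEnvRLLibWrapper, showing only how the
    # "env_config" dict reaches the env constructor; keys match the diff above.
    class RailEnvWrapperSketch:
        def __init__(self, env_config):
            self.env = RailEnv(width=env_config["width"],
                               height=env_config["height"],
                               rail_generator=env_config["rail_generator"],
                               number_of_agents=env_config["number_of_agents"])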
     def logger_creator(conf):
         """Creates a Unified logger with a default logdir prefix
         containing the agent name and the env id
         """
-        print("FOLDER", config['policy_folder_name'])
         logdir = config['policy_folder_name'].format(**locals())
         logdir = tempfile.mkdtemp(
             prefix=logdir, dir=config['local_dir'])
@@ -109,19 +93,18 @@ def train(config, reporter):
     logger = logger_creator
 
-    ppo_trainer = PPOTrainer(env=RailEnvRLLibWrapper, config=agent_config, logger_creator=logger)
+    trainer = Trainer(env=RailEnvRLLibWrapper, config=trainer_config, logger_creator=logger)
 
     for i in range(100000 + 2):
         print("== Iteration", i, "==")
 
-        print("-- PPO --")
-        print(pretty_print(ppo_trainer.train()))
+        print(pretty_print(trainer.train()))
 
         if i % config['save_every'] == 0:
-            checkpoint = ppo_trainer.save()
+            checkpoint = trainer.save()
             print("checkpoint saved at", checkpoint)
 
-        reporter(num_iterations_trained=ppo_trainer._iteration)
+        reporter(num_iterations_trained=trainer._iteration)
 
 
 @gin.configurable
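Since save() returns a checkpoint path, resuming training is symmetric. A minimal sketch using the trainer's standard restore() call (the checkpoint variable is the path printed in the loop above):

    # Rebuild a trainer with the same config, then load the saved state.
    trainer = Trainer(env=RailEnvRLLibWrapper, config=trainer_config, logger_creator=logger)
    trainer.restore(checkpoint)  # path returned earlier by trainer.save()
    print(pretty_print(trainer.train()))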
@@ -151,6 +134,6 @@ def run_grid_search(name, num_iterations, n_agents, hidden_sizes, save_every,
 if __name__ == '__main__':
     gin.external_configurable(tune.grid_search)
 
-    dir = '/mount/SDC/flatland/baselines/grid_search_configs/n_agents_grid_search'
+    dir = '/mount/SDC/flatland/baselines/grid_search_configs/n_agents_grid_search'  # To Modify
     gin.parse_config_file(dir + '/config.gin')
     run_grid_search(local_dir=dir)