Commit d68164db authored by Guillaume Mollard

Some changes to make trainer modification more convenient

parent 4306c599
@@ -4,9 +4,11 @@ import gym
from flatland.envs.generators import complex_rail_generator
import ray.rllib.agents.ppo.ppo as ppo
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
# Import PPO trainer: we can replace these imports by any other trainer from RLLib.
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
from ray.rllib.models import ModelCatalog
from ray.tune.logger import pretty_print
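Note, as a sketch that is not part of this commit: the aliased imports above are the only place PPO is named, which is what makes the trainer swappable. Assuming a Ray release of the same era that still ships the policy-graph based agents, switching the script to DQN would mostly reduce to re-pointing the aliases (trainer-specific config keys such as 'simple_optimizer' may still need adjusting):

# Hypothetical swap to DQN instead of PPO; module paths assume the same pre-1.0 RLlib layout used above.
from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG
from ray.rllib.agents.dqn.dqn import DQNTrainer as Trainer
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph as PolicyGraph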
@@ -27,6 +29,7 @@ from ray import tune
ModelCatalog.register_custom_preprocessor("my_prep", CustomPreprocessor)
ray.init(object_store_memory=150000000000)
def train(config, reporter):
    print('Init Env')
@@ -42,66 +45,47 @@ def train(config, reporter):
                              1, # Case 1c (9) - simple turn left
                              1] # Case 2b (10) - simple switch mirrored
    # Example generate a random rail
    """
    env = RailEnv(width=10,
                  height=10,
                  rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
                  number_of_agents=1)
    """
    # Example configuration to generate a random rail
    env_config = {"width":config['map_width'],
                  "height":config['map_height'],
                  "rail_generator":complex_rail_generator(nr_start_goal=config['n_agents'], min_dist=5, max_dist=99999, seed=0),
                  "number_of_agents":config['n_agents']}
"""
env = RailEnv(width=20,
height=20,
rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
['../notebooks/temp.npy']),
number_of_agents=3)
"""
    # Example generate a random rail
    # env = RailEnvRLLibWrapper(width=config['map_width'], height=config['map_height'],
    #                           rail_generator=complex_rail_generator(nr_start_goal=config["n_agents"], nr_extra=20, min_dist=12),
    #                           number_of_agents=config["n_agents"])
    # Observation space and action space definitions
    obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
    act_space = gym.spaces.Discrete(4)
    # Dict with the different policies to train
    policy_graphs = {
        config['policy_folder_name'].format(**locals()): (PPOPolicyGraph, obs_space, act_space, {})
        config['policy_folder_name'].format(**locals()): (PolicyGraph, obs_space, act_space, {})
    }
    def policy_mapping_fn(agent_id):
        return config['policy_folder_name'].format(**locals())
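As a sketch that is not part of this commit: the diff keeps a single policy whose name is built from config['policy_folder_name'], so policy_mapping_fn always resolves to that one entry. The same structure extends to several policies trained side by side; the policy names and the assumption of integer agent ids below are illustrative only:

    # Illustrative multi-policy setup: two PPO policies, agents routed by id parity.
    policy_graphs = {
        "ppo_policy_0": (PolicyGraph, obs_space, act_space, {}),
        "ppo_policy_1": (PolicyGraph, obs_space, act_space, {}),
    }
    def policy_mapping_fn(agent_id):
        # Assumes the wrapper hands out integer agent ids.
        return "ppo_policy_{}".format(agent_id % 2)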
    agent_config = ppo.DEFAULT_CONFIG.copy()
    agent_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": "my_prep"}
    agent_config['multiagent'] = {"policy_graphs": policy_graphs,
    # Trainer configuration
    trainer_config = DEFAULT_CONFIG.copy()
    trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": "my_prep"}
    trainer_config['multiagent'] = {"policy_graphs": policy_graphs,
                                    "policy_mapping_fn": policy_mapping_fn,
                                    "policies_to_train": list(policy_graphs.keys())}
agent_config["horizon"] = config['horizon']
agent_config["num_workers"] = 0
agent_config["num_cpus_per_worker"] = 10
agent_config["num_gpus"] = 0.5
agent_config["num_gpus_per_worker"] = 0.5
agent_config["num_cpus_for_driver"] = 2
agent_config["num_envs_per_worker"] = 10
agent_config["env_config"] = env_config
agent_config["batch_mode"] = "complete_episodes"
agent_config['simple_optimizer'] = False
trainer_config["horizon"] = config['horizon']
trainer_config["num_workers"] = 0
trainer_config["num_cpus_per_worker"] = 10
trainer_config["num_gpus"] = 0.5
trainer_config["num_gpus_per_worker"] = 0.5
trainer_config["num_cpus_for_driver"] = 2
trainer_config["num_envs_per_worker"] = 10
trainer_config["env_config"] = env_config
trainer_config["batch_mode"] = "complete_episodes"
trainer_config['simple_optimizer'] = False
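As a sketch that is not part of this commit: below, the environment class is passed directly as env=RailEnvRLLibWrapper and trainer_config["env_config"] is forwarded to it. An equivalent pattern is to register the wrapper under a name with Ray Tune's registry, assuming RailEnvRLLibWrapper accepts the env_config dict in its constructor:

    # Alternative env wiring via the Tune registry (the name "flatland_rail" is made up).
    from ray.tune.registry import register_env
    register_env("flatland_rail", lambda cfg: RailEnvRLLibWrapper(cfg))
    trainer = Trainer(env="flatland_rail", config=trainer_config)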
    def logger_creator(conf):
        """Creates a Unified logger with a default logdir prefix
        containing the agent name and the env id
        """
        print("FOLDER", config['policy_folder_name'])
        logdir = config['policy_folder_name'].format(**locals())
        logdir = tempfile.mkdtemp(
            prefix=logdir, dir=config['local_dir'])
@@ -109,19 +93,18 @@ def train(config, reporter):
    logger = logger_creator
    ppo_trainer = PPOTrainer(env=RailEnvRLLibWrapper, config=agent_config, logger_creator=logger)
    trainer = Trainer(env=RailEnvRLLibWrapper, config=trainer_config, logger_creator=logger)
    for i in range(100000 + 2):
        print("== Iteration", i, "==")
        print("-- PPO --")
        print(pretty_print(ppo_trainer.train()))
        print(pretty_print(trainer.train()))
        if i % config['save_every'] == 0:
            checkpoint = ppo_trainer.save()
            checkpoint = trainer.save()
            print("checkpoint saved at", checkpoint)
        reporter(num_iterations_trained=ppo_trainer._iteration)
        reporter(num_iterations_trained=trainer._iteration)
@gin.configurable
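As a sketch that is not part of this commit: the loop above writes a checkpoint every config['save_every'] iterations through trainer.save(). Resuming from one of those checkpoints uses the standard Trainer API; the path below is a placeholder for a value previously returned by trainer.save():

trainer = Trainer(env=RailEnvRLLibWrapper, config=trainer_config, logger_creator=logger)
trainer.restore("/path/to/checkpoint")  # placeholder for a path returned by trainer.save()
print(pretty_print(trainer.train()))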
@@ -151,6 +134,6 @@ def run_grid_search(name, num_iterations, n_agents, hidden_sizes, save_every,
if __name__ == '__main__':
    gin.external_configurable(tune.grid_search)
    dir = '/mount/SDC/flatland/baselines/grid_search_configs/n_agents_grid_search'
    dir = '/mount/SDC/flatland/baselines/grid_search_configs/n_agents_grid_search' # To Modify
    gin.parse_config_file(dir + '/config.gin')
    run_grid_search(local_dir=dir)
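As a sketch that is not part of this commit: run_grid_search takes all of its hyperparameters from config.gin via gin.parse_config_file, with tune.grid_search made gin-configurable for parameter sweeps. Equivalent programmatic bindings would look roughly like the following; the parameter names come from the (truncated) signature shown above, the values are placeholders, and any parameters not visible in this diff would need bindings as well:

    import gin
    # Illustrative bindings only; the actual values live in <local_dir>/config.gin.
    gin.parse_config([
        "run_grid_search.name = 'n_agents_grid_search'",
        "run_grid_search.num_iterations = 1000",
        "run_grid_search.n_agents = 2",
        "run_grid_search.hidden_sizes = [32, 32]",
        "run_grid_search.save_every = 100",
    ])
    run_grid_search(local_dir='/path/to/grid_search_configs/n_agents_grid_search')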