diff --git a/README.md b/README.md index 24f629aab61881cd11473f1e062ee76974bedb6b..24f06db7f764afc80b4a40655ef639a1db4703cf 100644 --- a/README.md +++ b/README.md @@ -9,9 +9,40 @@ It should be clone inside the main flatland repository. To start a grid search on some parameters, you can create a folder containing a config.gin file (see example in `grid_search_configs/n_agents_grid_search/config.gin`. -Then, you can modify the config.gin file path at the end of the grid_search_train.py file. +Then, you can modify the config.gin file path at the end of the `grid_search_train.py` file. The results will be stored inside the folder, and the learning curves can be visualized in tensorboard: -`tensorboard --logdir=/path/to/foler_containing_config_gin_file`. \ No newline at end of file +`tensorboard --logdir=/path/to/folder_containing_config_gin_file`. + +## Gin config files + +In each config.gin file, all the parameters of the `run_experiment` function, except `local_dir`, have to be specified. +For example, to indicate the number of agents that have to be initialized at the beginning of each simulation, the following line should be added: + +`run_experiment.n_agents = 2` + +If several numbers of agents have to be explored during the experiment, one can pass the following value to the `n_agents` parameter: + +`run_experiment.n_agents = {"grid_search": [2,5]}` + +which is the way to tell the tune library to try several values for a parameter. 
+ +To reference a class or an object within gin, you should first register it from the `train_experiment.py` script by adding the following line: + +`gin.external_configurable(TreeObsForRailEnv)` + +and then a `TreeObsForRailEnv` object can be referenced in the `config.gin` file: + +``` +run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv(), @GlobalObsForRailEnv()]} +TreeObsForRailEnv.max_depth = 2 +``` + +Note that `@TreeObsForRailEnv` references the class, while `@TreeObsForRailEnv()` instantiates an object of this class. + + + + +More documentation on how to use gin-config can be found on the library's GitHub repository: https://github.com/google/gin-config diff --git a/RailEnvRLLibWrapper.py b/RailEnvRLLibWrapper.py index 007568e18d0195d8f7536709517f908b41500a7c..e35a46e6e8b551a08ebd78b15029da25fb8c80ec 100644 --- a/RailEnvRLLibWrapper.py +++ b/RailEnvRLLibWrapper.py @@ -1,6 +1,6 @@ from flatland.envs.rail_env import RailEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv -from flatland.core.env_observation_builder import TreeObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.generators import random_rail_generator from ray.rllib.utils.seed import seed as set_seed @@ -17,8 +17,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv): seed=config['seed'] * (1+config.vector_index)) set_seed(config['seed'] * (1+config.vector_index)) self.env = RailEnv(width=config["width"], height=config["height"], rail_generator=self.rail_generator, - number_of_agents=config["number_of_agents"]) - + number_of_agents=config["number_of_agents"], obs_builder_object=config['obs_builder']) + def reset(self): self.agents_done = [] return self.env.reset() diff --git a/experiment_configs/observation_benchmark/config.gin b/experiment_configs/observation_benchmark/config.gin index 2d9e7264440df585f5458442ada509d1ebee5678..1fdadd2d4b4236978c1d2c7866c957683d106fd6 100644 --- a/experiment_configs/observation_benchmark/config.gin +++ 
b/experiment_configs/observation_benchmark/config.gin @@ -1,4 +1,4 @@ -run_experiment.name = "n_agents_results" +run_experiment.name = "observation_benchmark_results" run_experiment.num_iterations = 1002 run_experiment.save_every = 200 run_experiment.hidden_sizes = [32, 32] @@ -6,11 +6,11 @@ run_experiment.hidden_sizes = [32, 32] run_experiment.map_width = 20 run_experiment.map_height = 20 run_experiment.n_agents = {"grid_search": [2, 5]} -run_experiment.policy_folder_name = "ppo_policy_{config[n_agents]}_agents" +run_experiment.policy_folder_name = "ppo_policy_{config[obs_builder].__class__.__name__}_{config[n_agents]}_agents" run_experiment.horizon = 50 run_experiment.seed = 123 -run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv, @GlobalObsForRailEnv]} +run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv(), @GlobalObsForRailEnv()]} TreeObsForRailEnv.max_depth = 2 diff --git a/train_experiment.py b/train_experiment.py index b53bac63c3f7e2023728ca81e7ca1a6fe766d365..d8154164e7abb69dd92dc726b4278d4caf0c16eb 100644 --- a/train_experiment.py +++ b/train_experiment.py @@ -27,9 +27,13 @@ from ray import tune from ray.rllib.utils.seed import seed as set_seed from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv +gin.external_configurable(TreeObsForRailEnv) +gin.external_configurable(GlobalObsForRailEnv) + from ray.rllib.models.preprocessors import TupleFlatteningPreprocessor ModelCatalog.register_custom_preprocessor("tree_obs_prep", CustomPreprocessor) +ModelCatalog.register_custom_preprocessor("global_obs_prep", TupleFlatteningPreprocessor) ray.init() @@ -55,20 +59,23 @@ def train(config, reporter): "height": config['map_height'], "rail_generator": complex_rail_generator, "number_of_agents": config['n_agents'], - "seed": config['seed']} + "seed": config['seed'], + "obs_builder": config['obs_builder']} + print(config["obs_builder"]) + print(config["obs_builder"].__class__) + print(type(TreeObsForRailEnv)) # Observation 
space and action space definitions - if type(config["obs_builder"]) == TreeObsForRailEnv: + if isinstance(config["obs_builder"], TreeObsForRailEnv): obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,)) preprocessor = "tree_obs_prep" - elif type(config["obs_builder"]) == GlobalObsForRailEnv: + elif isinstance(config["obs_builder"], GlobalObsForRailEnv): obs_space = gym.spaces.Tuple(( gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)), gym.spaces.Box(low=0, high=1, shape=(4, config['map_height'], config['map_width'])), - gym.spaces.Space(4))) - - preprocessor = TupleFlatteningPreprocessor + gym.spaces.Box(low=0, high=1, shape=(4,)))) + preprocessor = "global_obs_prep" else: raise ValueError("Undefined observation space") @@ -94,11 +101,11 @@ def train(config, reporter): trainer_config["horizon"] = config['horizon'] trainer_config["num_workers"] = 0 - trainer_config["num_cpus_per_worker"] = 1 - trainer_config["num_gpus"] = 0.0 - trainer_config["num_gpus_per_worker"] = 0 - trainer_config["num_cpus_for_driver"] = 1 - trainer_config["num_envs_per_worker"] = 1 + trainer_config["num_cpus_per_worker"] = 8 + trainer_config["num_gpus"] = 0.5 + trainer_config["num_gpus_per_worker"] = 0.5 + trainer_config["num_cpus_for_driver"] = 2 + trainer_config["num_envs_per_worker"] = 10 trainer_config["env_config"] = env_config trainer_config["batch_mode"] = "complete_episodes" trainer_config['simple_optimizer'] = False @@ -130,7 +137,7 @@ def train(config, reporter): @gin.configurable def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, - map_width, map_height, horizon, policy_folder_name, local_dir, seed): + map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder, seed): tune.run( train, @@ -144,10 +151,11 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, "local_dir": local_dir, "horizon": horizon, # Max number of time steps 'policy_folder_name': 
policy_folder_name, + "obs_builder": obs_builder, "seed": seed }, resources_per_trial={ - "cpu": 12, + "cpu": 10, "gpu": 0.5 }, local_dir=local_dir