Commit 4470e889 authored by gmollard

observation benchmark working correctly

parent a54b734a
@@ -9,9 +9,40 @@ It should be cloned inside the main flatland repository.
To start a grid search on some parameters, you can create a folder containing a config.gin file (see the example in `grid_search_configs/n_agents_grid_search/config.gin`).
Then, you can modify the config.gin file path at the end of the `grid_search_train.py` file.
The results will be stored inside the folder, and the learning curves can be visualized in TensorBoard:
`tensorboard --logdir=/path/to/folder_containing_config_gin_file`.
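The `grid_search_train.py` file itself is not part of this diff; as a rough, hypothetical sketch (the import path and the call are assumptions, not the actual file contents), its last lines are expected to look like this:

```python
# Hypothetical sketch only: the actual grid_search_train.py is not shown in this
# commit. It illustrates "modify the config.gin file path at the end of the file".
import gin
from train_experiment import run_experiment  # assumed import, for illustration

# Folder that contains the config.gin to run; results are written back into it.
CONFIG_FOLDER = "grid_search_configs/n_agents_grid_search"

gin.parse_config_file(CONFIG_FOLDER + "/config.gin")

# local_dir is the only run_experiment parameter not set in config.gin (see below).
run_experiment(local_dir=CONFIG_FOLDER)
```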
## Gin config files
In each config.gin file, all the parameters of the `run_experiment` function, except `local_dir`, have to be specified.
For example, to indicate the number of agents that have to be initialized at the beginning of each simulation, the following line should be added:
`run_experiment.n_agents = 2`
If several numbers of agents have to be explored during the experiment, one can pass the following value to the `n_agents` parameter:
`run_experiment.n_agents = {"grid_search": [2,5]}`
which tells the tune library to try several values for this parameter.
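As a standalone illustration (separate from this repository's code), the sketch below shows how Ray Tune expands a `{"grid_search": [...]}` entry into one trial per listed value; it assumes a Ray version whose function trainables take `(config, reporter)`, as `train_experiment.py` does, and `dummy_train` is a made-up trainable used only here:

```python
# Standalone sketch of Ray Tune's grid_search expansion.
from ray import tune

def dummy_train(config, reporter):
    # Each trial receives one concrete value, here either 2 or 5.
    reporter(n_agents_seen=config["n_agents"], done=True)

# Launches two trials: one with n_agents=2 and one with n_agents=5.
tune.run(dummy_train, config={"n_agents": {"grid_search": [2, 5]}})
```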
To reference a class or an object within gin, you should first register it from the `train_experiment.py` script by adding the following line:
`gin.external_configurable(TreeObsForRailEnv)`
and then a `TreeObsForRailEnv` object can be referenced in the `config.gin` file:
```
run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv(), @GlobalObsForRailEnv()]}
TreeObsForRailEnv.max_depth = 2
```
Note that `@TreeObsForRailEnv` references the class, while `@TreeObsForRailEnv()` instantiates an object of this class.
More documentation on how to use gin-config can be found in the library's GitHub repository: https://github.com/google/gin-config
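A minimal standalone sketch of this class-versus-instance distinction follows; `show_obs_builder` is a hypothetical function used only for illustration, and flatland plus gin-config are assumed to be installed:

```python
# Standalone sketch; show_obs_builder is a hypothetical function, not repo code.
import gin
from flatland.envs.observations import TreeObsForRailEnv

gin.external_configurable(TreeObsForRailEnv)  # makes @TreeObsForRailEnv usable in configs

@gin.configurable
def show_obs_builder(obs_builder):
    # With "@TreeObsForRailEnv" gin passes the class itself;
    # with "@TreeObsForRailEnv()" it passes an already-built instance.
    print(type(obs_builder), obs_builder)

gin.parse_config("""
TreeObsForRailEnv.max_depth = 2
show_obs_builder.obs_builder = @TreeObsForRailEnv()
""")
show_obs_builder()  # prints a TreeObsForRailEnv instance built with max_depth=2
```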
```diff
 from flatland.envs.rail_env import RailEnv
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
-from flatland.core.env_observation_builder import TreeObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.generators import random_rail_generator
 from ray.rllib.utils.seed import seed as set_seed
@@ -17,8 +17,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
                                seed=config['seed'] * (1+config.vector_index))
         set_seed(config['seed'] * (1+config.vector_index))
         self.env = RailEnv(width=config["width"], height=config["height"], rail_generator=self.rail_generator,
-                           number_of_agents=config["number_of_agents"])
+                           number_of_agents=config["number_of_agents"], obs_builder_object=config['obs_builder'])

     def reset(self):
         self.agents_done = []
         return self.env.reset()
```
...
```diff
-run_experiment.name = "n_agents_results"
+run_experiment.name = "observation_benchmark_results"
 run_experiment.num_iterations = 1002
 run_experiment.save_every = 200
 run_experiment.hidden_sizes = [32, 32]
@@ -6,11 +6,11 @@ run_experiment.hidden_sizes = [32, 32]
 run_experiment.map_width = 20
 run_experiment.map_height = 20
 run_experiment.n_agents = {"grid_search": [2, 5]}
-run_experiment.policy_folder_name = "ppo_policy_{config[n_agents]}_agents"
+run_experiment.policy_folder_name = "ppo_policy_{config[obs_builder].__class__.__name__}_{config[n_agents]}_agents"
 run_experiment.horizon = 50
 run_experiment.seed = 123
-run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv, @GlobalObsForRailEnv]}
+run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv(), @GlobalObsForRailEnv()]}
 TreeObsForRailEnv.max_depth = 2
```
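The new `policy_folder_name` value pulls the observation builder's class name and the number of agents into the results folder name. Assuming it is filled in with `str.format(config=config)` somewhere in `train_experiment.py` (the formatting call is not shown in this diff), it resolves as in this sketch:

```python
# Standalone sketch of how the template is expected to resolve; the actual
# formatting call in train_experiment.py is an assumption.
from flatland.envs.observations import TreeObsForRailEnv

config = {"obs_builder": TreeObsForRailEnv(max_depth=2), "n_agents": 2}
template = "ppo_policy_{config[obs_builder].__class__.__name__}_{config[n_agents]}_agents"
print(template.format(config=config))  # -> ppo_policy_TreeObsForRailEnv_2_agents
```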
```diff
@@ -27,9 +27,13 @@ from ray import tune
 from ray.rllib.utils.seed import seed as set_seed
 from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
+gin.external_configurable(TreeObsForRailEnv)
+gin.external_configurable(GlobalObsForRailEnv)
 from ray.rllib.models.preprocessors import TupleFlatteningPreprocessor
 ModelCatalog.register_custom_preprocessor("tree_obs_prep", CustomPreprocessor)
+ModelCatalog.register_custom_preprocessor("global_obs_prep", TupleFlatteningPreprocessor)
 ray.init()
@@ -55,20 +59,23 @@ def train(config, reporter):
                   "height": config['map_height'],
                   "rail_generator": complex_rail_generator,
                   "number_of_agents": config['n_agents'],
-                  "seed": config['seed']}
+                  "seed": config['seed'],
+                  "obs_builder": config['obs_builder']}
+
+    print(config["obs_builder"])
+    print(config["obs_builder"].__class__)
+    print(type(TreeObsForRailEnv))

     # Observation space and action space definitions
-    if type(config["obs_builder"]) == TreeObsForRailEnv:
+    if isinstance(config["obs_builder"], TreeObsForRailEnv):
         obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
         preprocessor = "tree_obs_prep"
-    elif type(config["obs_builder"]) == GlobalObsForRailEnv:
+    elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
         obs_space = gym.spaces.Tuple((
             gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)),
             gym.spaces.Box(low=0, high=1, shape=(4, config['map_height'], config['map_width'])),
-            gym.spaces.Space(4)))
-        preprocessor = TupleFlatteningPreprocessor
+            gym.spaces.Box(low=0, high=1, shape=(4,))))
+        preprocessor = "global_obs_prep"
     else:
         raise ValueError("Undefined observation space")
@@ -94,11 +101,11 @@ def train(config, reporter):
     trainer_config["horizon"] = config['horizon']
     trainer_config["num_workers"] = 0
-    trainer_config["num_cpus_per_worker"] = 1
-    trainer_config["num_gpus"] = 0.0
-    trainer_config["num_gpus_per_worker"] = 0
-    trainer_config["num_cpus_for_driver"] = 1
-    trainer_config["num_envs_per_worker"] = 1
+    trainer_config["num_cpus_per_worker"] = 8
+    trainer_config["num_gpus"] = 0.5
+    trainer_config["num_gpus_per_worker"] = 0.5
+    trainer_config["num_cpus_for_driver"] = 2
+    trainer_config["num_envs_per_worker"] = 10
     trainer_config["env_config"] = env_config
     trainer_config["batch_mode"] = "complete_episodes"
     trainer_config['simple_optimizer'] = False
@@ -130,7 +137,7 @@ def train(config, reporter):
 @gin.configurable
 def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
-                   map_width, map_height, horizon, policy_folder_name, local_dir, seed):
+                   map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder, seed):
     tune.run(
         train,
@@ -144,10 +151,11 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
             "local_dir": local_dir,
             "horizon": horizon,  # Max number of time steps
             'policy_folder_name': policy_folder_name,
+            "obs_builder": obs_builder,
             "seed": seed
         },
         resources_per_trial={
-            "cpu": 12,
+            "cpu": 10,
             "gpu": 0.5
         },
         local_dir=local_dir
```
...