Skip to content
Snippets Groups Projects
Commit bc400346 authored by gmollard's avatar gmollard
Browse files

removed the term 'grid search'

parent b40798ff
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,7 @@ from flatland.envs.rail_env import RailEnv ...@@ -2,7 +2,7 @@ from flatland.envs.rail_env import RailEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv
from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.envs.generators import random_rail_generator from flatland.envs.generators import random_rail_generator
from ray.rllib.utils.seed import seed as set_seed
class RailEnvRLLibWrapper(MultiAgentEnv): class RailEnvRLLibWrapper(MultiAgentEnv):
...@@ -13,8 +13,10 @@ class RailEnvRLLibWrapper(MultiAgentEnv): ...@@ -13,8 +13,10 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
# number_of_agents=1, # number_of_agents=1,
# obs_builder_object=TreeObsForRailEnv(max_depth=2)): # obs_builder_object=TreeObsForRailEnv(max_depth=2)):
super(MultiAgentEnv, self).__init__() super(MultiAgentEnv, self).__init__()
self.rail_generator = config["rail_generator"](nr_start_goal=config['number_of_agents'], min_dist=5,
self.env = RailEnv(width=config["width"], height=config["height"], rail_generator=config["rail_generator"], seed=config['seed'] * (1+config.vector_index))
set_seed(config['seed'] * (1+config.vector_index))
self.env = RailEnv(width=config["width"], height=config["height"], rail_generator=self.rail_generator,
number_of_agents=config["number_of_agents"]) number_of_agents=config["number_of_agents"])
def reset(self): def reset(self):
......
run_experiment.name = "n_agents_results"
run_experiment.num_iterations = 1002
run_experiment.save_every = 200
run_experiment.hidden_sizes = [32, 32]
run_experiment.map_width = 20
run_experiment.map_height = 20
run_experiment.n_agents = {"grid_search": [1]}#, 2, 5, 10]}
run_experiment.policy_folder_name = "ppo_policy_{config[n_agents]}_agents"
run_experiment.horizon = 50
run_experiment.seed = 123
run_grid_search.name = "n_agents_results"
run_grid_search.num_iterations = 1002
run_grid_search.save_every = 200
run_grid_search.hidden_sizes = [32, 32]
run_grid_search.map_width = 20
run_grid_search.map_height = 20
run_grid_search.n_agents = {"grid_search": [1, 2, 5, 10]}
run_grid_search.policy_folder_name = "ppo_policy_{config[n_agents]}_agents"
run_grid_search.horizon = 50
...@@ -24,15 +24,17 @@ import tempfile ...@@ -24,15 +24,17 @@ import tempfile
import gin import gin
from ray import tune from ray import tune
from ray.rllib.utils.seed import seed as set_seed
ModelCatalog.register_custom_preprocessor("my_prep", CustomPreprocessor) ModelCatalog.register_custom_preprocessor("my_prep", CustomPreprocessor)
ray.init(object_store_memory=150000000000) ray.init()
def train(config, reporter): def train(config, reporter):
print('Init Env') print('Init Env')
set_seed(config['seed'], config['seed'], config['seed'])
transition_probability = [15, # empty cell - Case 0 transition_probability = [15, # empty cell - Case 0
5, # Case 1 - straight 5, # Case 1 - straight
5, # Case 2 - simple switch 5, # Case 2 - simple switch
...@@ -46,10 +48,11 @@ def train(config, reporter): ...@@ -46,10 +48,11 @@ def train(config, reporter):
1] # Case 2b (10) - simple switch mirrored 1] # Case 2b (10) - simple switch mirrored
# Example configuration to generate a random rail # Example configuration to generate a random rail
env_config = {"width":config['map_width'], env_config = {"width": config['map_width'],
"height":config['map_height'], "height": config['map_height'],
"rail_generator":complex_rail_generator(nr_start_goal=config['n_agents'], min_dist=5, max_dist=99999, seed=0), "rail_generator": complex_rail_generator,
"number_of_agents":config['n_agents']} "number_of_agents": config['n_agents'],
"seed": config['seed']}
# Observation space and action space definitions # Observation space and action space definitions
obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,)) obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
...@@ -73,11 +76,11 @@ def train(config, reporter): ...@@ -73,11 +76,11 @@ def train(config, reporter):
trainer_config["horizon"] = config['horizon'] trainer_config["horizon"] = config['horizon']
trainer_config["num_workers"] = 0 trainer_config["num_workers"] = 0
trainer_config["num_cpus_per_worker"] = 10 trainer_config["num_cpus_per_worker"] = 1
trainer_config["num_gpus"] = 0.5 trainer_config["num_gpus"] = 0.0
trainer_config["num_gpus_per_worker"] = 0.5 trainer_config["num_gpus_per_worker"] = 0
trainer_config["num_cpus_for_driver"] = 2 trainer_config["num_cpus_for_driver"] = 1
trainer_config["num_envs_per_worker"] = 10 trainer_config["num_envs_per_worker"] = 1
trainer_config["env_config"] = env_config trainer_config["env_config"] = env_config
trainer_config["batch_mode"] = "complete_episodes" trainer_config["batch_mode"] = "complete_episodes"
trainer_config['simple_optimizer'] = False trainer_config['simple_optimizer'] = False
...@@ -108,8 +111,8 @@ def train(config, reporter): ...@@ -108,8 +111,8 @@ def train(config, reporter):
@gin.configurable @gin.configurable
def run_grid_search(name, num_iterations, n_agents, hidden_sizes, save_every, def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
map_width, map_height, horizon, policy_folder_name, local_dir): map_width, map_height, horizon, policy_folder_name, local_dir, seed):
tune.run( tune.run(
train, train,
...@@ -122,11 +125,12 @@ def run_grid_search(name, num_iterations, n_agents, hidden_sizes, save_every, ...@@ -122,11 +125,12 @@ def run_grid_search(name, num_iterations, n_agents, hidden_sizes, save_every,
"map_height": map_height, "map_height": map_height,
"local_dir": local_dir, "local_dir": local_dir,
"horizon": horizon, # Max number of time steps "horizon": horizon, # Max number of time steps
'policy_folder_name': policy_folder_name 'policy_folder_name': policy_folder_name,
"seed": seed
}, },
resources_per_trial={ resources_per_trial={
"cpu": 12, "cpu": 1,
"gpu": 0.5 "gpu": 0.0
}, },
local_dir=local_dir local_dir=local_dir
) )
...@@ -134,6 +138,6 @@ def run_grid_search(name, num_iterations, n_agents, hidden_sizes, save_every, ...@@ -134,6 +138,6 @@ def run_grid_search(name, num_iterations, n_agents, hidden_sizes, save_every,
if __name__ == '__main__': if __name__ == '__main__':
gin.external_configurable(tune.grid_search) gin.external_configurable(tune.grid_search)
dir = '/mount/SDC/flatland/baselines/grid_search_configs/n_agents_grid_search' # To Modify dir = '/home/guillaume/Desktop/distMAgent/baselines/experiment_configs/n_agents_experiment' # To Modify
gin.parse_config_file(dir + '/config.gin') gin.parse_config_file(dir + '/config.gin')
run_grid_search(local_dir=dir) run_experiment(local_dir=dir)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment