Skip to content
Snippets Groups Projects
Commit bc400346 authored by gmollard's avatar gmollard
Browse files

removed the term 'grid search'

parent b40798ff
No related branches found
No related tags found
No related merge requests found
......@@ -2,7 +2,7 @@ from flatland.envs.rail_env import RailEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.envs.generators import random_rail_generator
from ray.rllib.utils.seed import seed as set_seed
class RailEnvRLLibWrapper(MultiAgentEnv):
......@@ -13,8 +13,10 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
# number_of_agents=1,
# obs_builder_object=TreeObsForRailEnv(max_depth=2)):
super(MultiAgentEnv, self).__init__()
self.env = RailEnv(width=config["width"], height=config["height"], rail_generator=config["rail_generator"],
self.rail_generator = config["rail_generator"](nr_start_goal=config['number_of_agents'], min_dist=5,
seed=config['seed'] * (1+config.vector_index))
set_seed(config['seed'] * (1+config.vector_index))
self.env = RailEnv(width=config["width"], height=config["height"], rail_generator=self.rail_generator,
number_of_agents=config["number_of_agents"])
def reset(self):
......
run_experiment.name = "n_agents_results"
run_experiment.num_iterations = 1002
run_experiment.save_every = 200
run_experiment.hidden_sizes = [32, 32]
run_experiment.map_width = 20
run_experiment.map_height = 20
run_experiment.n_agents = {"grid_search": [1]}#, 2, 5, 10]}
run_experiment.policy_folder_name = "ppo_policy_{config[n_agents]}_agents"
run_experiment.horizon = 50
run_experiment.seed = 123
run_grid_search.name = "n_agents_results"
run_grid_search.num_iterations = 1002
run_grid_search.save_every = 200
run_grid_search.hidden_sizes = [32, 32]
run_grid_search.map_width = 20
run_grid_search.map_height = 20
run_grid_search.n_agents = {"grid_search": [1, 2, 5, 10]}
run_grid_search.policy_folder_name = "ppo_policy_{config[n_agents]}_agents"
run_grid_search.horizon = 50
......@@ -24,15 +24,17 @@ import tempfile
import gin
from ray import tune
from ray.rllib.utils.seed import seed as set_seed
ModelCatalog.register_custom_preprocessor("my_prep", CustomPreprocessor)
ray.init(object_store_memory=150000000000)
ray.init()
def train(config, reporter):
print('Init Env')
set_seed(config['seed'], config['seed'], config['seed'])
transition_probability = [15, # empty cell - Case 0
5, # Case 1 - straight
5, # Case 2 - simple switch
......@@ -46,10 +48,11 @@ def train(config, reporter):
1] # Case 2b (10) - simple switch mirrored
# Example configuration to generate a random rail
env_config = {"width":config['map_width'],
"height":config['map_height'],
"rail_generator":complex_rail_generator(nr_start_goal=config['n_agents'], min_dist=5, max_dist=99999, seed=0),
"number_of_agents":config['n_agents']}
env_config = {"width": config['map_width'],
"height": config['map_height'],
"rail_generator": complex_rail_generator,
"number_of_agents": config['n_agents'],
"seed": config['seed']}
# Observation space and action space definitions
obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
......@@ -73,11 +76,11 @@ def train(config, reporter):
trainer_config["horizon"] = config['horizon']
trainer_config["num_workers"] = 0
trainer_config["num_cpus_per_worker"] = 10
trainer_config["num_gpus"] = 0.5
trainer_config["num_gpus_per_worker"] = 0.5
trainer_config["num_cpus_for_driver"] = 2
trainer_config["num_envs_per_worker"] = 10
trainer_config["num_cpus_per_worker"] = 1
trainer_config["num_gpus"] = 0.0
trainer_config["num_gpus_per_worker"] = 0
trainer_config["num_cpus_for_driver"] = 1
trainer_config["num_envs_per_worker"] = 1
trainer_config["env_config"] = env_config
trainer_config["batch_mode"] = "complete_episodes"
trainer_config['simple_optimizer'] = False
......@@ -108,8 +111,8 @@ def train(config, reporter):
@gin.configurable
def run_grid_search(name, num_iterations, n_agents, hidden_sizes, save_every,
map_width, map_height, horizon, policy_folder_name, local_dir):
def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
map_width, map_height, horizon, policy_folder_name, local_dir, seed):
tune.run(
train,
......@@ -122,11 +125,12 @@ def run_grid_search(name, num_iterations, n_agents, hidden_sizes, save_every,
"map_height": map_height,
"local_dir": local_dir,
"horizon": horizon, # Max number of time steps
'policy_folder_name': policy_folder_name
'policy_folder_name': policy_folder_name,
"seed": seed
},
resources_per_trial={
"cpu": 12,
"gpu": 0.5
"cpu": 1,
"gpu": 0.0
},
local_dir=local_dir
)
......@@ -134,6 +138,6 @@ def run_grid_search(name, num_iterations, n_agents, hidden_sizes, save_every,
if __name__ == '__main__':
gin.external_configurable(tune.grid_search)
dir = '/mount/SDC/flatland/baselines/grid_search_configs/n_agents_grid_search' # To Modify
dir = '/home/guillaume/Desktop/distMAgent/baselines/experiment_configs/n_agents_experiment' # To Modify
gin.parse_config_file(dir + '/config.gin')
run_grid_search(local_dir=dir)
run_experiment(local_dir=dir)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment