Skip to content
Snippets Groups Projects
Commit 779e87d5 authored by Guillaume Mollard's avatar Guillaume Mollard
Browse files

env complexity grid search

parent 540dce24
No related branches found
No related tags found
No related merge requests found
...@@ -25,11 +25,13 @@ class RailEnvRLLibWrapper(MultiAgentEnv): ...@@ -25,11 +25,13 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
self.rail_generator = complex_rail_generator(nr_start_goal=config['number_of_agents'], min_dist=5, self.rail_generator = complex_rail_generator(nr_start_goal=config['number_of_agents'], min_dist=5,
nr_extra=config['nr_extra'], seed=config['seed'] * (1+vector_index)) nr_extra=config['nr_extra'], seed=config['seed'] * (1+vector_index))
else: else:
raise(Error)
self.rail_generator = random_rail_generator() self.rail_generator = random_rail_generator()
set_seed(config['seed'] * (1+vector_index)) set_seed(config['seed'] * (1+vector_index))
self.env = RailEnv(width=config["width"], height=config["height"], self.env = RailEnv(width=config["width"], height=config["height"],
number_of_agents=config["number_of_agents"], obs_builder_object=config['obs_builder']) number_of_agents=config["number_of_agents"],
obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator)
# self.env.load('/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl') # self.env.load('/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl')
......
...@@ -4,10 +4,10 @@ run_experiment.save_every = 50 ...@@ -4,10 +4,10 @@ run_experiment.save_every = 50
run_experiment.hidden_sizes = [32, 32] run_experiment.hidden_sizes = [32, 32]
run_experiment.map_width = 20 run_experiment.map_width = 20
run_experiment.map_height = 10 run_experiment.map_height = 20
run_experiment.n_agents = 8 run_experiment.n_agents = 8
run_experiment.rail_generator = "complex_rail_generator" run_experiment.rail_generator = "complex_rail_generator"
run_experiment.nr_extra = {"grid_search": [10, 20, 30, 40]} run_experiment.nr_extra = {"grid_search": [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]}
run_experiment.policy_folder_name = "ppo_policy_nr_extra_{config[nr_extra]}_" run_experiment.policy_folder_name = "ppo_policy_nr_extra_{config[nr_extra]}_"
run_experiment.horizon = 50 run_experiment.horizon = 50
......
...@@ -50,10 +50,6 @@ def train(config, reporter): ...@@ -50,10 +50,6 @@ def train(config, reporter):
set_seed(config['seed'], config['seed'], config['seed']) set_seed(config['seed'], config['seed'], config['seed'])
config['map_width']= 20
config['map_height']= 10
config['n_agents'] = 8
# Example configuration to generate a random rail # Example configuration to generate a random rail
env_config = {"width": config['map_width'], env_config = {"width": config['map_width'],
"height": config['map_height'], "height": config['map_height'],
...@@ -136,6 +132,8 @@ def train(config, reporter): ...@@ -136,6 +132,8 @@ def train(config, reporter):
trainer_config['simple_optimizer'] = False trainer_config['simple_optimizer'] = False
trainer_config['postprocess_inputs'] = True trainer_config['postprocess_inputs'] = True
trainer_config['log_level'] = 'WARN' trainer_config['log_level'] = 'WARN'
trainer_config['num_sgd_iter'] = 10
trainer_config['clip_param'] = 0.2
def logger_creator(conf): def logger_creator(conf):
"""Creates a Unified logger with a default logdir prefix """Creates a Unified logger with a default logdir prefix
...@@ -187,7 +185,7 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, ...@@ -187,7 +185,7 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
"nr_extra": nr_extra "nr_extra": nr_extra
}, },
resources_per_trial={ resources_per_trial={
"cpu": 2, "cpu": 4,
"gpu": 0.0 "gpu": 0.0
}, },
local_dir=local_dir local_dir=local_dir
...@@ -196,6 +194,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, ...@@ -196,6 +194,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
if __name__ == '__main__': if __name__ == '__main__':
gin.external_configurable(tune.grid_search) gin.external_configurable(tune.grid_search)
dir = '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/RLLib_training/experiment_configs/env_complexity_benchmark' # To Modify dir = '/home/guillaume/flatland/baselines/RLLib_training/experiment_configs/env_complexity_benchmark' # To Modify
gin.parse_config_file(dir + '/config.gin') gin.parse_config_file(dir + '/config.gin')
run_experiment(local_dir=dir) run_experiment(local_dir=dir)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment