diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py index 5642520a0e3649878789c093b07b3f0f06fb3f32..c54884df9f8c647c0f13136649828b309d6b075e 100644 --- a/RLLib_training/RailEnvRLLibWrapper.py +++ b/RLLib_training/RailEnvRLLibWrapper.py @@ -25,11 +25,13 @@ class RailEnvRLLibWrapper(MultiAgentEnv): self.rail_generator = complex_rail_generator(nr_start_goal=config['number_of_agents'], min_dist=5, nr_extra=config['nr_extra'], seed=config['seed'] * (1+vector_index)) else: + raise(Error) self.rail_generator = random_rail_generator() set_seed(config['seed'] * (1+vector_index)) self.env = RailEnv(width=config["width"], height=config["height"], - number_of_agents=config["number_of_agents"], obs_builder_object=config['obs_builder']) + number_of_agents=config["number_of_agents"], + obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator) # self.env.load('/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl') diff --git a/RLLib_training/experiment_configs/env_complexity_benchmark/config.gin b/RLLib_training/experiment_configs/env_complexity_benchmark/config.gin index 82305a640693dbfa7efd946e3eb671727e0f72a5..7f7f075528d081750674695f82fef4f595d81355 100644 --- a/RLLib_training/experiment_configs/env_complexity_benchmark/config.gin +++ b/RLLib_training/experiment_configs/env_complexity_benchmark/config.gin @@ -4,10 +4,10 @@ run_experiment.save_every = 50 run_experiment.hidden_sizes = [32, 32] run_experiment.map_width = 20 -run_experiment.map_height = 10 +run_experiment.map_height = 20 run_experiment.n_agents = 8 run_experiment.rail_generator = "complex_rail_generator" -run_experiment.nr_extra = {"grid_search": [10, 20, 30, 40]} +run_experiment.nr_extra = {"grid_search": [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]} run_experiment.policy_folder_name = "ppo_policy_nr_extra_{config[nr_extra]}_" run_experiment.horizon = 50 diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py index 57fb0ceb642cabd562075c3aae8e7cd8e6240460..8e74d12a29d58781d1feba2a0af11973882d5c3a 100644 --- a/RLLib_training/train_experiment.py +++ b/RLLib_training/train_experiment.py @@ -50,10 +50,6 @@ def train(config, reporter): set_seed(config['seed'], config['seed'], config['seed']) - config['map_width']= 20 - config['map_height']= 10 - config['n_agents'] = 8 - # Example configuration to generate a random rail env_config = {"width": config['map_width'], "height": config['map_height'], @@ -136,6 +132,8 @@ def train(config, reporter): trainer_config['simple_optimizer'] = False trainer_config['postprocess_inputs'] = True trainer_config['log_level'] = 'WARN' + trainer_config['num_sgd_iter'] = 10 + trainer_config['clip_param'] = 0.2 def logger_creator(conf): """Creates a Unified logger with a default logdir prefix @@ -187,7 +185,7 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, "nr_extra": nr_extra }, resources_per_trial={ - "cpu": 2, + "cpu": 4, "gpu": 0.0 }, local_dir=local_dir @@ -196,6 +194,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, if __name__ == '__main__': gin.external_configurable(tune.grid_search) - dir = '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/RLLib_training/experiment_configs/env_complexity_benchmark' # To Modify + dir = '/home/guillaume/flatland/baselines/RLLib_training/experiment_configs/env_complexity_benchmark' # To Modify gin.parse_config_file(dir + '/config.gin') run_experiment(local_dir=dir)