diff --git a/experiment_configs/entropy_coeff_benchmark/config.gin b/experiment_configs/entropy_coeff_benchmark/config.gin
new file mode 100644
index 0000000000000000000000000000000000000000..e674447137f4efe2962f36d26d446bf5e99af073
--- /dev/null
+++ b/experiment_configs/entropy_coeff_benchmark/config.gin
@@ -0,0 +1,19 @@
+run_experiment.name = "observation_benchmark_results"
+run_experiment.num_iterations = 1002
+run_experiment.save_every = 100
+run_experiment.hidden_sizes = {"grid_search": [[32, 32], [64, 64], [128, 128], [256, 256]]}
+
+run_experiment.map_width = 20
+run_experiment.map_height = 20
+run_experiment.n_agents = 5
+run_experiment.policy_folder_name = "ppo_policy_{config[obs_builder].__class__.__name__}_entropy_coeff_{config[entropy_coeff]}_{config[hidden_sizes][0]}_hidden_sizes_"
+
+run_experiment.horizon = 50
+run_experiment.seed = 123
+
+run_experiment.entropy_coeff = {"grid_search": [1e-3, 1e-2, 0]}
+
+run_experiment.obs_builder = {"grid_search": [@LocalObsForRailEnv()]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
+TreeObsForRailEnv.max_depth = 2
+LocalObsForRailEnv.view_radius = 5
+
diff --git a/train_experiment.py b/train_experiment.py
index 16c52b657db58c6f69e8cf91004d4259461be2cd..330acedb07a554780ddb4fdadbb9d9b14be626f0 100644
--- a/train_experiment.py
+++ b/train_experiment.py
@@ -116,6 +116,7 @@ def train(config, reporter):
     trainer_config["num_gpus_per_worker"] = 0
     trainer_config["num_cpus_for_driver"] = 1
     trainer_config["num_envs_per_worker"] = 1
+    trainer_config['entropy_coeff'] = config['entropy_coeff']
     trainer_config["env_config"] = env_config
     trainer_config["batch_mode"] = "complete_episodes"
     trainer_config['simple_optimizer'] = False
@@ -149,7 +150,8 @@ def train(config, reporter):
 
 @gin.configurable
 def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
-                   map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder, seed):
+                   map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder,
+                   entropy_coeff, seed):
 
     tune.run(
         train,
@@ -164,6 +166,7 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
             "horizon": horizon,  # Max number of time steps
             'policy_folder_name': policy_folder_name,
             "obs_builder": obs_builder,
+            "entropy_coeff": entropy_coeff,
             "seed": seed
         },
         resources_per_trial={