Skip to content
Snippets Groups Projects
Commit b9836597 authored by Guillaume Mollard's avatar Guillaume Mollard
Browse files

more parameters added for PPO training

parent 779e87d5
No related branches found
No related tags found
No related merge requests found
...@@ -44,6 +44,14 @@ class RailEnvRLLibWrapper(MultiAgentEnv): ...@@ -44,6 +44,14 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
self.agents_done = [] self.agents_done = []
obs = self.env.reset() obs = self.env.reset()
o = dict() o = dict()
#for agent, _ in obs.items():
#o[agent] = obs[agent]
# one_hot_agent_encoding = np.zeros(len(self.env.agents))
# one_hot_agent_encoding[agent] += 1
# o[agent] = np.append(obs[agent], one_hot_agent_encoding)
# o['agents'] = obs # o['agents'] = obs
# obs[0] = [obs[0], np.ones((17, 17)) * 17] # obs[0] = [obs[0], np.ones((17, 17)) * 17]
# obs['global_obs'] = np.ones((17, 17)) * 17 # obs['global_obs'] = np.ones((17, 17)) * 17
...@@ -67,7 +75,10 @@ class RailEnvRLLibWrapper(MultiAgentEnv): ...@@ -67,7 +75,10 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
for agent, done in dones.items(): for agent, done in dones.items():
if agent not in self.agents_done: if agent not in self.agents_done:
if agent != '__all__': if agent != '__all__':
o[agent] = obs[agent] # o[agent] = obs[agent]
#one_hot_agent_encoding = np.zeros(len(self.env.agents))
#one_hot_agent_encoding[agent] += 1
o[agent] = obs[agent]#np.append(obs[agent], one_hot_agent_encoding)
r[agent] = rewards[agent] r[agent] = rewards[agent]
d[agent] = dones[agent] d[agent] = dones[agent]
......
...@@ -54,6 +54,8 @@ class CustomPreprocessor(Preprocessor): ...@@ -54,6 +54,8 @@ class CustomPreprocessor(Preprocessor):
def transform(self, observation): def transform(self, observation):
# if len(observation) == 111: # if len(observation) == 111:
return norm_obs_clip(observation) return norm_obs_clip(observation)
one_hot = observation[-3:]
return np.append(obs, one_hot)
# else: # else:
# return observation # return observation
......
...@@ -3,12 +3,12 @@ run_experiment.num_iterations = 1002 ...@@ -3,12 +3,12 @@ run_experiment.num_iterations = 1002
run_experiment.save_every = 50 run_experiment.save_every = 50
run_experiment.hidden_sizes = [32, 32] run_experiment.hidden_sizes = [32, 32]
run_experiment.map_width = 20 run_experiment.map_width = 15
run_experiment.map_height = 20 run_experiment.map_height = 15
run_experiment.n_agents = 8 run_experiment.n_agents = 8
run_experiment.rail_generator = "complex_rail_generator" run_experiment.rail_generator = "complex_rail_generator"
run_experiment.nr_extra = {"grid_search": [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]} run_experiment.nr_extra = {"grid_search": [0, 5, 10, 20, 30, 40, 50, 60]}
run_experiment.policy_folder_name = "ppo_policy_nr_extra_{config[nr_extra]}_" run_experiment.policy_folder_name = "ppo_policy_nr_extra_{config[nr_extra]}_map_width_{config[map_width]}_"
run_experiment.horizon = 50 run_experiment.horizon = 50
run_experiment.seed = 123 run_experiment.seed = 123
......
run_experiment.name = "observation_benchmark_results"
run_experiment.num_iterations = 2002
run_experiment.save_every = 50
run_experiment.hidden_sizes = [32, 32]
run_experiment.map_width = 8
run_experiment.map_height = 8
run_experiment.n_agents = 3
run_experiment.rail_generator = "complex_rail_generator"
run_experiment.nr_extra = 5#{"grid_search": [0, 5, 10, 20, 30, 40, 50, 60]}
run_experiment.policy_folder_name = "ppo_policy_{config[obs_builder].__class__.__name__}_kl_coeff_{config[kl_coeff]}_lambda_gae_{config[lambda_gae]}_horizon_{config[horizon]}_"
run_experiment.horizon = {"grid_search": [30, 50]}
run_experiment.seed = 123
#run_experiment.conv_model = {"grid_search": [True, False]}
run_experiment.conv_model = False
#run_experiment.obs_builder = {"grid_search": [@GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv(), @GlobalObsForRailEnv()]}
TreeObsForRailEnv.max_depth = 2
LocalObsForRailEnv.view_radius = 5
run_experiment.entropy_coeff = 0.01
run_experiment.kl_coeff = {"grid_search": [0, 0.2]}
run_experiment.lambda_gae = {"grid_search": [0.9, 1.0]}
...@@ -121,7 +121,7 @@ def train(config, reporter): ...@@ -121,7 +121,7 @@ def train(config, reporter):
trainer_config["horizon"] = config['horizon'] trainer_config["horizon"] = config['horizon']
trainer_config["num_workers"] = 0 trainer_config["num_workers"] = 0
trainer_config["num_cpus_per_worker"] = 3 trainer_config["num_cpus_per_worker"] = 2
trainer_config["num_gpus"] = 0 trainer_config["num_gpus"] = 0
trainer_config["num_gpus_per_worker"] = 0 trainer_config["num_gpus_per_worker"] = 0
trainer_config["num_cpus_for_driver"] = 1 trainer_config["num_cpus_for_driver"] = 1
...@@ -134,6 +134,8 @@ def train(config, reporter): ...@@ -134,6 +134,8 @@ def train(config, reporter):
trainer_config['log_level'] = 'WARN' trainer_config['log_level'] = 'WARN'
trainer_config['num_sgd_iter'] = 10 trainer_config['num_sgd_iter'] = 10
trainer_config['clip_param'] = 0.2 trainer_config['clip_param'] = 0.2
trainer_config['kl_coeff'] = config['kl_coeff']
trainer_config['lambda'] = config['lambda_gae']
def logger_creator(conf): def logger_creator(conf):
"""Creates a Unified logger with a default logdir prefix """Creates a Unified logger with a default logdir prefix
...@@ -163,7 +165,7 @@ def train(config, reporter): ...@@ -163,7 +165,7 @@ def train(config, reporter):
@gin.configurable @gin.configurable
def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder, map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder,
entropy_coeff, seed, conv_model, rail_generator, nr_extra): entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae):
tune.run( tune.run(
train, train,
...@@ -182,10 +184,12 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, ...@@ -182,10 +184,12 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
"seed": seed, "seed": seed,
"conv_model": conv_model, "conv_model": conv_model,
"rail_generator": rail_generator, "rail_generator": rail_generator,
"nr_extra": nr_extra "nr_extra": nr_extra,
"kl_coeff": kl_coeff,
"lambda_gae": lambda_gae
}, },
resources_per_trial={ resources_per_trial={
"cpu": 4, "cpu": 3,
"gpu": 0.0 "gpu": 0.0
}, },
local_dir=local_dir local_dir=local_dir
...@@ -194,6 +198,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, ...@@ -194,6 +198,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
if __name__ == '__main__': if __name__ == '__main__':
gin.external_configurable(tune.grid_search) gin.external_configurable(tune.grid_search)
dir = '/home/guillaume/flatland/baselines/RLLib_training/experiment_configs/env_complexity_benchmark' # To Modify dir = '/home/guillaume/flatland/baselines/RLLib_training/experiment_configs/env_size_benchmark_3_agents' # To Modify
gin.parse_config_file(dir + '/config.gin') gin.parse_config_file(dir + '/config.gin')
run_experiment(local_dir=dir) run_experiment(local_dir=dir)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment