Skip to content
Snippets Groups Projects
Commit 2a6f9556 authored by gmollard
Browse files

Merge branch 'master' of gitlab.aicrowd.com:flatland/baselines

parents 14a359da b42940e3
No related branches found
No related tags found
No related merge requests found
...@@ -55,12 +55,14 @@ class CustomPreprocessor(Preprocessor): ...@@ -55,12 +55,14 @@ class CustomPreprocessor(Preprocessor):
# return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0] * obs_space[2].shape[1]),) # return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0] * obs_space[2].shape[1]),)
def transform(self, observation): def transform(self, observation):
print('OBSSSSSSSSSSSSSSSSSs', observation, observation.shape) data = norm_obs_clip(observation[0][0])
data = norm_obs_clip(observation[0]) distance = norm_obs_clip(observation[0][1])
distance = norm_obs_clip(observation[1]) agent_data = np.clip(observation[0][2], -1, 1)
agent_data = np.clip(observation[2], -1, 1) data2 = norm_obs_clip(observation[1][0])
distance2 = norm_obs_clip(observation[1][1])
return np.concatenate((np.concatenate((data, distance)), agent_data)) agent_data2 = np.clip(observation[1][2], -1, 1)
return np.concatenate((np.concatenate((np.concatenate((data, distance)), agent_data)), np.concatenate((np.concatenate((data2, distance2)), agent_data2))))
return norm_obs_clip(observation) return norm_obs_clip(observation)
return np.concatenate([norm_obs_clip(observation[0]), norm_obs_clip(observation[1])]) return np.concatenate([norm_obs_clip(observation[0]), norm_obs_clip(observation[1])])
# if len(observation) == 111: # if len(observation) == 111:
......
...@@ -3,12 +3,12 @@ run_experiment.num_iterations = 2002 ...@@ -3,12 +3,12 @@ run_experiment.num_iterations = 2002
run_experiment.save_every = 100 run_experiment.save_every = 100
run_experiment.hidden_sizes = [32, 32] run_experiment.hidden_sizes = [32, 32]
run_experiment.map_width = 20 run_experiment.map_width = 40
run_experiment.map_height = 20 run_experiment.map_height = 40
run_experiment.n_agents = {"grid_search": [3, 4, 5, 6, 7, 8]} run_experiment.n_agents = {"grid_search": [3, 4, 5, 6, 7, 8]}
run_experiment.rail_generator = "complex_rail_generator" run_experiment.rail_generator = "complex_rail_generator"
run_experiment.nr_extra = 5 run_experiment.nr_extra = 5
run_experiment.policy_folder_name = "ppo_policy_two_obs_with_predictions_n_agents_{config[n_agents]}_" run_experiment.policy_folder_name = "ppo_policy_two_obs_with_predictions_n_agents_{config[n_agents]}__map_size_{config[map_width]}"
#run_experiment.horizon = #run_experiment.horizon =
run_experiment.seed = 123 run_experiment.seed = 123
...@@ -18,12 +18,13 @@ run_experiment.conv_model = False ...@@ -18,12 +18,13 @@ run_experiment.conv_model = False
#run_experiment.obs_builder = {"grid_search": [@GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]} #run_experiment.obs_builder = {"grid_search": [@GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
run_experiment.obs_builder = @TreeObsForRailEnv() run_experiment.obs_builder = @TreeObsForRailEnv()
TreeObsForRailEnv.predictor = @ShortestPathPredictorForRailEnv TreeObsForRailEnv.predictor = @ShortestPathPredictorForRailEnv()
TreeObsForRailEnv.max_depth = 2 TreeObsForRailEnv.max_depth = 2
LocalObsForRailEnv.view_radius = 5 LocalObsForRailEnv.view_radius = 5
run_experiment.entropy_coeff = 0.001 run_experiment.entropy_coeff = 0.001
run_experiment.kl_coeff = 0.2 #{"grid_search": [0, 0.2]} run_experiment.kl_coeff = 0.2 #{"grid_search": [0, 0.2]}
run_experiment.lambda_gae = 0.9 # {"grid_search": [0.9, 1.0]} run_experiment.lambda_gae = 0.9 # {"grid_search": [0.9, 1.0]}
#run_experiment.predictor = "ShortestPathPredictorForRailEnv" #run_experiment.predictor = "ShortestPathPredictorForRailEnv()"
run_experiment.step_memory = 2 run_experiment.step_memory = 2
run_experiment.min_dist = 10
...@@ -80,12 +80,11 @@ def train(config, reporter): ...@@ -80,12 +80,11 @@ def train(config, reporter):
"seed": config['seed'], "seed": config['seed'],
"obs_builder": config['obs_builder'], "obs_builder": config['obs_builder'],
"min_dist": config['min_dist'], "min_dist": config['min_dist'],
"predictor": config["predictor"],
"step_memory": config["step_memory"]} "step_memory": config["step_memory"]}
# Observation space and action space definitions # Observation space and action space definitions
if isinstance(config["obs_builder"], TreeObsForRailEnv): if isinstance(config["obs_builder"], TreeObsForRailEnv):
obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(168,)), )) obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(168,)),) * 2)
preprocessor = "tree_obs_prep" preprocessor = "tree_obs_prep"
elif isinstance(config["obs_builder"], GlobalObsForRailEnv): elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
...@@ -190,7 +189,7 @@ def train(config, reporter): ...@@ -190,7 +189,7 @@ def train(config, reporter):
@gin.configurable @gin.configurable
def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder, map_width, map_height, policy_folder_name, local_dir, obs_builder,
entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae, entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae,
step_memory, min_dist): step_memory, min_dist):
tune.run( tune.run(
...@@ -203,7 +202,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, ...@@ -203,7 +202,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
"map_width": map_width, "map_width": map_width,
"map_height": map_height, "map_height": map_height,
"local_dir": local_dir, "local_dir": local_dir,
"horizon": horizon, # Max number of time steps
'policy_folder_name': policy_folder_name, 'policy_folder_name': policy_folder_name,
"obs_builder": obs_builder, "obs_builder": obs_builder,
"entropy_coeff": entropy_coeff, "entropy_coeff": entropy_coeff,
...@@ -230,7 +228,7 @@ if __name__ == '__main__': ...@@ -230,7 +228,7 @@ if __name__ == '__main__':
gin.external_configurable(tune.grid_search) gin.external_configurable(tune.grid_search)
# with path('RLLib_training.experiment_configs.n_agents_experiment', 'config.gin') as f: # with path('RLLib_training.experiment_configs.n_agents_experiment', 'config.gin') as f:
# gin.parse_config_file(f) # gin.parse_config_file(f)
gin.parse_config_file('/home/guillaume/flatland/baselines/RLLib_training/experiment_configs/score_metric_test/config.gin') gin.parse_config_file('/mount/SDC/flatland/baselines/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin')
dir = '/home/guillaume/flatland/baselines/RLLib_training/experiment_configs/score_metric_test' dir = '/mount/SDC/flatland/baselines/RLLib_training/experiment_configs/env_size_benchmark_3_agents'
# dir = os.path.join(__file_dirname__, 'experiment_configs', 'experiment_agent_memory') # dir = os.path.join(__file_dirname__, 'experiment_configs', 'experiment_agent_memory')
run_experiment(local_dir=dir) run_experiment(local_dir=dir)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment