Skip to content
Snippets Groups Projects
Commit 2a6f9556 authored by gmollard's avatar gmollard
Browse files

Merge branch 'master' of gitlab.aicrowd.com:flatland/baselines

parents 14a359da b42940e3
No related branches found
No related tags found
No related merge requests found
......@@ -55,12 +55,14 @@ class CustomPreprocessor(Preprocessor):
# return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0] * obs_space[2].shape[1]),)
def transform(self, observation):
print('OBSSSSSSSSSSSSSSSSSs', observation, observation.shape)
data = norm_obs_clip(observation[0])
distance = norm_obs_clip(observation[1])
agent_data = np.clip(observation[2], -1, 1)
return np.concatenate((np.concatenate((data, distance)), agent_data))
data = norm_obs_clip(observation[0][0])
distance = norm_obs_clip(observation[0][1])
agent_data = np.clip(observation[0][2], -1, 1)
data2 = norm_obs_clip(observation[1][0])
distance2 = norm_obs_clip(observation[1][1])
agent_data2 = np.clip(observation[1][2], -1, 1)
return np.concatenate((np.concatenate((np.concatenate((data, distance)), agent_data)), np.concatenate((np.concatenate((data2, distance2)), agent_data2))))
return norm_obs_clip(observation)
return np.concatenate([norm_obs_clip(observation[0]), norm_obs_clip(observation[1])])
# if len(observation) == 111:
......
......@@ -3,12 +3,12 @@ run_experiment.num_iterations = 2002
run_experiment.save_every = 100
run_experiment.hidden_sizes = [32, 32]
run_experiment.map_width = 20
run_experiment.map_height = 20
run_experiment.map_width = 40
run_experiment.map_height = 40
run_experiment.n_agents = {"grid_search": [3, 4, 5, 6, 7, 8]}
run_experiment.rail_generator = "complex_rail_generator"
run_experiment.nr_extra = 5
run_experiment.policy_folder_name = "ppo_policy_two_obs_with_predictions_n_agents_{config[n_agents]}_"
run_experiment.policy_folder_name = "ppo_policy_two_obs_with_predictions_n_agents_{config[n_agents]}__map_size_{config[map_width]}"
#run_experiment.horizon =
run_experiment.seed = 123
......@@ -18,12 +18,13 @@ run_experiment.conv_model = False
#run_experiment.obs_builder = {"grid_search": [@GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
run_experiment.obs_builder = @TreeObsForRailEnv()
TreeObsForRailEnv.predictor = @ShortestPathPredictorForRailEnv
TreeObsForRailEnv.predictor = @ShortestPathPredictorForRailEnv()
TreeObsForRailEnv.max_depth = 2
LocalObsForRailEnv.view_radius = 5
run_experiment.entropy_coeff = 0.001
run_experiment.kl_coeff = 0.2 #{"grid_search": [0, 0.2]}
run_experiment.lambda_gae = 0.9 # {"grid_search": [0.9, 1.0]}
#run_experiment.predictor = "ShortestPathPredictorForRailEnv"
#run_experiment.predictor = "ShortestPathPredictorForRailEnv()"
run_experiment.step_memory = 2
run_experiment.min_dist = 10
......@@ -80,12 +80,11 @@ def train(config, reporter):
"seed": config['seed'],
"obs_builder": config['obs_builder'],
"min_dist": config['min_dist'],
"predictor": config["predictor"],
"step_memory": config["step_memory"]}
# Observation space and action space definitions
if isinstance(config["obs_builder"], TreeObsForRailEnv):
obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(168,)), ))
obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(168,)),) * 2)
preprocessor = "tree_obs_prep"
elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
......@@ -190,7 +189,7 @@ def train(config, reporter):
@gin.configurable
def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder,
map_width, map_height, policy_folder_name, local_dir, obs_builder,
entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae,
step_memory, min_dist):
tune.run(
......@@ -203,7 +202,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
"map_width": map_width,
"map_height": map_height,
"local_dir": local_dir,
"horizon": horizon, # Max number of time steps
'policy_folder_name': policy_folder_name,
"obs_builder": obs_builder,
"entropy_coeff": entropy_coeff,
......@@ -230,7 +228,7 @@ if __name__ == '__main__':
gin.external_configurable(tune.grid_search)
# with path('RLLib_training.experiment_configs.n_agents_experiment', 'config.gin') as f:
# gin.parse_config_file(f)
gin.parse_config_file('/home/guillaume/flatland/baselines/RLLib_training/experiment_configs/score_metric_test/config.gin')
dir = '/home/guillaume/flatland/baselines/RLLib_training/experiment_configs/score_metric_test'
gin.parse_config_file('/mount/SDC/flatland/baselines/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin')
dir = '/mount/SDC/flatland/baselines/RLLib_training/experiment_configs/env_size_benchmark_3_agents'
# dir = os.path.join(__file_dirname__, 'experiment_configs', 'experiment_agent_memory')
run_experiment(local_dir=dir)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment