diff --git a/RLLib_training/custom_preprocessors.py b/RLLib_training/custom_preprocessors.py index 41e895bee95cde7bc59220ce2eee4eaabd458651..2f236f33b89395e4c4169ddd4e35b7ce4606cf6c 100644 --- a/RLLib_training/custom_preprocessors.py +++ b/RLLib_training/custom_preprocessors.py @@ -55,12 +55,14 @@ class CustomPreprocessor(Preprocessor): # return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0] * obs_space[2].shape[1]),) def transform(self, observation): - print('OBSSSSSSSSSSSSSSSSSs', observation, observation.shape) - data = norm_obs_clip(observation[0]) - distance = norm_obs_clip(observation[1]) - agent_data = np.clip(observation[2], -1, 1) - - return np.concatenate((np.concatenate((data, distance)), agent_data)) + data = norm_obs_clip(observation[0][0]) + distance = norm_obs_clip(observation[0][1]) + agent_data = np.clip(observation[0][2], -1, 1) + data2 = norm_obs_clip(observation[1][0]) + distance2 = norm_obs_clip(observation[1][1]) + agent_data2 = np.clip(observation[1][2], -1, 1) + + return np.concatenate((np.concatenate((np.concatenate((data, distance)), agent_data)), np.concatenate((np.concatenate((data2, distance2)), agent_data2)))) return norm_obs_clip(observation) return np.concatenate([norm_obs_clip(observation[0]), norm_obs_clip(observation[1])]) # if len(observation) == 111: diff --git a/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin b/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin index bbc3803807c3564e65039b798dbee8691ac2084b..3236d269ce924633ac5fafe4896f3cd91f6e3bd9 100644 --- a/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin +++ b/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin @@ -18,12 +18,13 @@ run_experiment.conv_model = False #run_experiment.obs_builder = {"grid_search": [@GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]} run_experiment.obs_builder = @TreeObsForRailEnv() -TreeObsForRailEnv.predictor = @ShortestPathPredictorForRailEnv +TreeObsForRailEnv.predictor = @ShortestPathPredictorForRailEnv() TreeObsForRailEnv.max_depth = 2 LocalObsForRailEnv.view_radius = 5 run_experiment.entropy_coeff = 0.001 run_experiment.kl_coeff = 0.2 #{"grid_search": [0, 0.2]} run_experiment.lambda_gae = 0.9 # {"grid_search": [0.9, 1.0]} -#run_experiment.predictor = "ShortestPathPredictorForRailEnv" +#run_experiment.predictor = "ShortestPathPredictorForRailEnv()" run_experiment.step_memory = 2 +run_experiment.min_dist = 10 diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py index cd25ad0d33a723922ff105bf1cdcfdaed283f3f1..9f9f77e340797d4a6d1820ae14b2744916ed981e 100644 --- a/RLLib_training/train_experiment.py +++ b/RLLib_training/train_experiment.py @@ -83,12 +83,11 @@ def train(config, reporter): "seed": config['seed'], "obs_builder": config['obs_builder'], "min_dist": config['min_dist'], - "predictor": config["predictor"], "step_memory": config["step_memory"]} # Observation space and action space definitions if isinstance(config["obs_builder"], TreeObsForRailEnv): - obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(168,)), )) + obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(168,)),) * 2) preprocessor = "tree_obs_prep" elif isinstance(config["obs_builder"], GlobalObsForRailEnv): @@ -193,7 +192,7 @@ def train(config, reporter): @gin.configurable def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, - map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder, + map_width, map_height, policy_folder_name, local_dir, obs_builder, entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae, step_memory, min_dist): tune.run( @@ -206,7 +205,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, "map_width": map_width, "map_height": map_height, "local_dir": local_dir, - "horizon": horizon, # Max number of time steps 'policy_folder_name': policy_folder_name, "obs_builder": obs_builder, "entropy_coeff": entropy_coeff, @@ -233,7 +231,7 @@ if __name__ == '__main__': gin.external_configurable(tune.grid_search) # with path('RLLib_training.experiment_configs.n_agents_experiment', 'config.gin') as f: # gin.parse_config_file(f) - gin.parse_config_file('/home/guillaume/flatland/baselines/RLLib_training/experiment_configs/score_metric_test/config.gin') - dir = '/home/guillaume/flatland/baselines/RLLib_training/experiment_configs/score_metric_test' + gin.parse_config_file('/mount/SDC/flatland/baselines/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin') + dir = '/mount/SDC/flatland/baselines/RLLib_training/experiment_configs/env_size_benchmark_3_agents' # dir = os.path.join(__file_dirname__, 'experiment_configs', 'experiment_agent_memory') run_experiment(local_dir=dir)