Skip to content
Snippets Groups Projects
Commit 84bcfcb1 authored by gmollard
Browse files

up to date with shortest path prediction

parent fdd64ac1
No related branches found
No related tags found
No related merge requests found
......@@ -55,12 +55,14 @@ class CustomPreprocessor(Preprocessor):
# return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0] * obs_space[2].shape[1]),)
def transform(self, observation):
print('OBSSSSSSSSSSSSSSSSSs', observation, observation.shape)
data = norm_obs_clip(observation[0])
distance = norm_obs_clip(observation[1])
agent_data = np.clip(observation[2], -1, 1)
return np.concatenate((np.concatenate((data, distance)), agent_data))
data = norm_obs_clip(observation[0][0])
distance = norm_obs_clip(observation[0][1])
agent_data = np.clip(observation[0][2], -1, 1)
data2 = norm_obs_clip(observation[1][0])
distance2 = norm_obs_clip(observation[1][1])
agent_data2 = np.clip(observation[1][2], -1, 1)
return np.concatenate((np.concatenate((np.concatenate((data, distance)), agent_data)), np.concatenate((np.concatenate((data2, distance2)), agent_data2))))
return norm_obs_clip(observation)
return np.concatenate([norm_obs_clip(observation[0]), norm_obs_clip(observation[1])])
# if len(observation) == 111:
......
......@@ -18,12 +18,13 @@ run_experiment.conv_model = False
#run_experiment.obs_builder = {"grid_search": [@GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
run_experiment.obs_builder = @TreeObsForRailEnv()
TreeObsForRailEnv.predictor = @ShortestPathPredictorForRailEnv
TreeObsForRailEnv.predictor = @ShortestPathPredictorForRailEnv()
TreeObsForRailEnv.max_depth = 2
LocalObsForRailEnv.view_radius = 5
run_experiment.entropy_coeff = 0.001
run_experiment.kl_coeff = 0.2 #{"grid_search": [0, 0.2]}
run_experiment.lambda_gae = 0.9 # {"grid_search": [0.9, 1.0]}
#run_experiment.predictor = "ShortestPathPredictorForRailEnv"
#run_experiment.predictor = "ShortestPathPredictorForRailEnv()"
run_experiment.step_memory = 2
run_experiment.min_dist = 10
......@@ -83,12 +83,11 @@ def train(config, reporter):
"seed": config['seed'],
"obs_builder": config['obs_builder'],
"min_dist": config['min_dist'],
"predictor": config["predictor"],
"step_memory": config["step_memory"]}
# Observation space and action space definitions
if isinstance(config["obs_builder"], TreeObsForRailEnv):
obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(168,)), ))
obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(168,)),) * 2)
preprocessor = "tree_obs_prep"
elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
......@@ -193,7 +192,7 @@ def train(config, reporter):
@gin.configurable
def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder,
map_width, map_height, policy_folder_name, local_dir, obs_builder,
entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae,
step_memory, min_dist):
tune.run(
......@@ -206,7 +205,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
"map_width": map_width,
"map_height": map_height,
"local_dir": local_dir,
"horizon": horizon, # Max number of time steps
'policy_folder_name': policy_folder_name,
"obs_builder": obs_builder,
"entropy_coeff": entropy_coeff,
......@@ -233,7 +231,7 @@ if __name__ == '__main__':
gin.external_configurable(tune.grid_search)
# with path('RLLib_training.experiment_configs.n_agents_experiment', 'config.gin') as f:
# gin.parse_config_file(f)
gin.parse_config_file('/home/guillaume/flatland/baselines/RLLib_training/experiment_configs/score_metric_test/config.gin')
dir = '/home/guillaume/flatland/baselines/RLLib_training/experiment_configs/score_metric_test'
gin.parse_config_file('/mount/SDC/flatland/baselines/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin')
dir = '/mount/SDC/flatland/baselines/RLLib_training/experiment_configs/env_size_benchmark_3_agents'
# dir = os.path.join(__file_dirname__, 'experiment_configs', 'experiment_agent_memory')
run_experiment(local_dir=dir)
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment