diff --git a/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin b/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin
index 258bc1d97ee321b6b61d17d2a010e7310a9351ca..bbc3803807c3564e65039b798dbee8691ac2084b 100644
--- a/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin
+++ b/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin
@@ -1,16 +1,16 @@
 run_experiment.name = "observation_benchmark_results"
 run_experiment.num_iterations = 2002
-run_experiment.save_every = 50
+run_experiment.save_every = 100
 run_experiment.hidden_sizes = [32, 32]
 
-run_experiment.map_width = 8
-run_experiment.map_height = 8
-run_experiment.n_agents = 3
+run_experiment.map_width = 20
+run_experiment.map_height = 20
+run_experiment.n_agents = {"grid_search": [3, 4, 5, 6, 7, 8]}
 run_experiment.rail_generator = "complex_rail_generator"
-run_experiment.nr_extra = 5#{"grid_search": [0, 5, 10, 20, 30, 40, 50, 60]}
-run_experiment.policy_folder_name = "ppo_policy_two_obs_with_predictions_kl_coeff_{config[kl_coeff]}_horizon_{config[horizon]}_"
+run_experiment.nr_extra = 5
+run_experiment.policy_folder_name = "ppo_policy_two_obs_with_predictions_n_agents_{config[n_agents]}_"
 
-run_experiment.horizon = {"grid_search": [50, 100]}
+#run_experiment.horizon =
 run_experiment.seed = 123
 
 #run_experiment.conv_model = {"grid_search": [True, False]}
@@ -18,9 +18,12 @@ run_experiment.conv_model = False
 
 #run_experiment.obs_builder = {"grid_search": [@GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
 run_experiment.obs_builder = @TreeObsForRailEnv()
+TreeObsForRailEnv.predictor = @ShortestPathPredictorForRailEnv
 TreeObsForRailEnv.max_depth = 2
 LocalObsForRailEnv.view_radius = 5
 
-run_experiment.entropy_coeff = 0.01
-run_experiment.kl_coeff = {"grid_search": [0, 0.2]}
-run_experiment.lambda_gae = 0.9# {"grid_search": [0.9, 1.0]}
+run_experiment.entropy_coeff = 0.001
+run_experiment.kl_coeff = 0.2 #{"grid_search": [0, 0.2]}
+run_experiment.lambda_gae = 0.9 # {"grid_search": [0.9, 1.0]}
+#run_experiment.predictor = "ShortestPathPredictorForRailEnv"
+run_experiment.step_memory = 2
diff --git a/RLLib_training/experiment_configs/experiment_agent_memory/config.gin b/RLLib_training/experiment_configs/experiment_agent_memory/config.gin
index 58df080ecd624d0076f06b81b785bb4ac29d3139..4de08004b5089f2b607052de95ac1b9cad30d0ab 100644
--- a/RLLib_training/experiment_configs/experiment_agent_memory/config.gin
+++ b/RLLib_training/experiment_configs/experiment_agent_memory/config.gin
@@ -1,7 +1,7 @@
 run_experiment.name = "memory_experiment_results"
 run_experiment.num_iterations = 2002
 run_experiment.save_every = 50
-run_experiment.hidden_sizes = {"grid_search": [[32, 32], [64, 64], [128, 128]]}
+run_experiment.hidden_sizes = [32, 32]#{"grid_search": [[32, 32], [64, 64], [128, 128]]}
 
 run_experiment.map_width = 8
 run_experiment.map_height = 8
@@ -20,7 +20,7 @@ run_experiment.obs_builder = @TreeObsForRailEnv()
 TreeObsForRailEnv.max_depth = 2
 LocalObsForRailEnv.view_radius = 5
 
-run_experiment.entropy_coeff = {"grid_search": [1e-4, 1e-3, 1e-2]}
+run_experiment.entropy_coeff = 1e-4#{"grid_search": [1e-4, 1e-3, 1e-2]}
 run_experiment.kl_coeff = 0.2
 run_experiment.lambda_gae = 0.9
 run_experiment.predictor = None#@DummyPredictorForRailEnv()
diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py
index cc8debe1b89a315624c99560b63858e66f2dea1e..cd25ad0d33a723922ff105bf1cdcfdaed283f3f1 100644
--- a/RLLib_training/train_experiment.py
+++ b/RLLib_training/train_experiment.py
@@ -2,7 +2,7 @@ import os
 
 import gin
 import gym
-from flatland.envs.predictions import DummyPredictorForRailEnv
+from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
 from importlib_resources import path
 # Import PPO trainer: we can replace these imports by any other trainer from RLLib.
 from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
@@ -11,6 +11,7 @@ from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
 from ray.rllib.models import ModelCatalog
 
 gin.external_configurable(DummyPredictorForRailEnv)
+gin.external_configurable(ShortestPathPredictorForRailEnv)
 
 import ray
 
@@ -66,6 +67,7 @@ def on_episode_end(info):
     score /= (len(episode._agent_reward_history) * 3 * episode.horizon)
     episode.custom_metrics["score"] = score
 
+
 def train(config, reporter):
     print('Init Env')
 
@@ -81,23 +83,12 @@ def train(config, reporter):
                   "seed": config['seed'],
                   "obs_builder": config['obs_builder'],
                   "min_dist": config['min_dist'],
-                  # "predictor": config["predictor"],
+                  "predictor": config["predictor"],
                   "step_memory": config["step_memory"]}
 
     # Observation space and action space definitions
     if isinstance(config["obs_builder"], TreeObsForRailEnv):
         obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(168,)), ))
-                                      # gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
-                                      # gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),) * config[
-                                      #    'step_memory'])
-        # if config['predictor'] is None:
-        #     obs_space = gym.spaces.Tuple(
-        #         (gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),) * config['step_memory'])
-        # else:
-        #     obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),
-        #                                   gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
-        #                                   gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),) * config[
-        #                                      'step_memory'])
         preprocessor = "tree_obs_prep"
 
     elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
@@ -152,7 +143,7 @@ def train(config, reporter):
     trainer_config['multiagent'] = {"policy_graphs": policy_graphs,
                                     "policy_mapping_fn": policy_mapping_fn,
                                     "policies_to_train": list(policy_graphs.keys())}
-    trainer_config["horizon"] = 1.5 * (config['map_width'] + config['map_height'])#config['horizon']
+    trainer_config["horizon"] = 3 * (config['map_width'] + config['map_height'])#config['horizon']
 
     trainer_config["num_workers"] = 0
     trainer_config["num_cpus_per_worker"] = 7