diff --git a/train_experiment.py b/train_experiment.py
index 68b45684bc4dcaf5efa190f731d03a436e94dbaa..0c1af1727abaaaf3078e24d7a250f071bda6cb9c 100644
--- a/train_experiment.py
+++ b/train_experiment.py
@@ -8,7 +8,8 @@ from flatland.envs.generators import complex_rail_generator
 # Import PPO trainer: we can replace these imports by any other trainer from RLLib.
 from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
 from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
-from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
+# from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
+from baselines.CustomPPOPolicyGraph import CustomPPOPolicyGraph as PolicyGraph
 
 from ray.rllib.models import ModelCatalog
 from ray.tune.logger import pretty_print
@@ -34,7 +35,7 @@ from ray.rllib.models.preprocessors import TupleFlatteningPreprocessor
 
 ModelCatalog.register_custom_preprocessor("tree_obs_prep", CustomPreprocessor)
 ModelCatalog.register_custom_preprocessor("global_obs_prep", TupleFlatteningPreprocessor)
-ray.init(object_store_memory=150000000000, redis_max_memory=30000000000)
+ray.init()#object_store_memory=150000000000, redis_max_memory=30000000000)
 
 
 def train(config, reporter):
@@ -62,9 +63,6 @@ def train(config, reporter):
                   "seed": config['seed'],
                   "obs_builder": config['obs_builder']}
 
-    print(config["obs_builder"])
-    print(config["obs_builder"].__class__)
-    print(type(TreeObsForRailEnv))
     # Observation space and action space definitions
     if isinstance(config["obs_builder"], TreeObsForRailEnv):
         obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
@@ -73,7 +71,8 @@ def train(config, reporter):
     elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
         obs_space = gym.spaces.Tuple((
             gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)),
-            gym.spaces.Box(low=0, high=1, shape=(4, config['map_height'], config['map_width'])),
+            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 3)),
+            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 4)),
             gym.spaces.Box(low=0, high=1, shape=(4,))))
         preprocessor = "global_obs_prep"
 
@@ -101,14 +100,15 @@ def train(config, reporter):
     trainer_config["horizon"] = config['horizon']
 
     trainer_config["num_workers"] = 0
-    trainer_config["num_cpus_per_worker"] = 10
-    trainer_config["num_gpus"] = 0.5
-    trainer_config["num_gpus_per_worker"] = 0.5
-    trainer_config["num_cpus_for_driver"] = 2
-    trainer_config["num_envs_per_worker"] = 10
+    trainer_config["num_cpus_per_worker"] = 3
+    trainer_config["num_gpus"] = 0
+    trainer_config["num_gpus_per_worker"] = 0
+    trainer_config["num_cpus_for_driver"] = 1
+    trainer_config["num_envs_per_worker"] = 1
     trainer_config["env_config"] = env_config
     trainer_config["batch_mode"] = "complete_episodes"
-    trainer_config['simple_optimizer'] = False
+    trainer_config['simple_optimizer'] = True
+    trainer_config['postprocess_inputs'] = True
 
     def logger_creator(conf):
         """Creates a Unified logger with a default logdir prefix
@@ -155,8 +155,8 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
                 "seed": seed
                 },
         resources_per_trial={
-            "cpu": 12,
-            "gpu": 0.5
+            "cpu": 2,
+            "gpu": 0.0
         },
         local_dir=local_dir
     )
@@ -164,6 +164,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
 
 if __name__ == '__main__':
     gin.external_configurable(tune.grid_search)
-    dir = '/mount/SDC/flatland/baselines/experiment_configs/observation_benchmark'  # To Modify
+    dir = '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/experiment_configs/observation_benchmark'  # To Modify
     gin.parse_config_file(dir + '/config.gin')
     run_experiment(local_dir=dir)