diff --git a/RailEnvRLLibWrapper.py b/RailEnvRLLibWrapper.py
index da54ad01f6d5536c0aab129509ae501f93e5ead9..fbdee61432f0aa2f558f543c002e9846570dc140 100644
--- a/RailEnvRLLibWrapper.py
+++ b/RailEnvRLLibWrapper.py
@@ -3,6 +3,9 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.generators import random_rail_generator
 from ray.rllib.utils.seed import seed as set_seed
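+# NOTE: numpy is referenced only by the commented-out observation experiments below.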
+import numpy as np
+
 
 class RailEnvRLLibWrapper(MultiAgentEnv):
 
@@ -13,15 +16,22 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
                  # number_of_agents=1,
                  # obs_builder_object=TreeObsForRailEnv(max_depth=2)):
         super(MultiAgentEnv, self).__init__()
-        self.rail_generator = config["rail_generator"](nr_start_goal=config['number_of_agents'], min_dist=5, nr_extra=30,
-                                                       seed=config['seed'] * (1+config.vector_index))
+        self.rail_generator = config["rail_generator"](nr_start_goal=config['number_of_agents'], min_dist=5,
+                                                       nr_extra=30, seed=config['seed'] * (1+config.vector_index))
         set_seed(config['seed'] * (1+config.vector_index))
         self.env = RailEnv(width=config["width"], height=config["height"], rail_generator=self.rail_generator,
                 number_of_agents=config["number_of_agents"], obs_builder_object=config['obs_builder'])
     
     def reset(self):
         self.agents_done = []
-        return self.env.reset()
+        obs = self.env.reset()
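+        # Disabled experiment: wrap the per-agent observations and attach a
+        # dummy 17x17 global view; kept for reference.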
+        # o = dict()
+        # o['agents'] = obs
+        # obs[0] = [obs[0], np.ones((17, 17)) * 17]
+        # obs['global_obs'] = np.ones((17, 17)) * 17
+        return obs
 
     def step(self, action_dict):
         obs, rewards, dones, infos = self.env.step(action_dict)
@@ -46,7 +56,17 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         
         #print(obs)
         #return obs, rewards, dones, infos
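+        # Disabled experiments with alternative observation layouts, including a
+        # synthetic 'global_obs' pseudo-agent entry; kept for reference.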
+        # oo = dict()
+        # oo['agents'] = o
+        # o['global'] = np.ones((17, 17)) * 17
+
+        # o[0] = [o[0], np.ones((17, 17)) * 17]
+        # o['global_obs'] = np.ones((17, 17)) * 17
+        # r['global_obs'] = 0
+        # d['global_obs'] = True
         return o, r, d, infos
-    
+
     def get_agent_handles(self):
         return self.env.get_agent_handles()
diff --git a/experiment_configs/CustomModels.py b/experiment_configs/CustomModels.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/experiment_configs/observation_benchmark/config.gin b/experiment_configs/observation_benchmark/config.gin
index 9f3c0727dd6b5a922b2a8b212e4bf5e6f77f0dab..f5a4dc80396e26476296a4d8b83cb4882a0f1033 100644
--- a/experiment_configs/observation_benchmark/config.gin
+++ b/experiment_configs/observation_benchmark/config.gin
@@ -5,12 +5,14 @@ run_experiment.hidden_sizes = [32, 32]
 
 run_experiment.map_width = 20
 run_experiment.map_height = 20
-run_experiment.n_agents = {"grid_search": [2, 5]}
+run_experiment.n_agents = 5
 run_experiment.policy_folder_name = "ppo_policy_{config[obs_builder].__class__.__name__}_{config[n_agents]}_agents"
 
 run_experiment.horizon = 50
 run_experiment.seed = 123
 
-run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv(), @GlobalObsForRailEnv()]}
+run_experiment.obs_builder = {"grid_search": [@LocalObsForRailEnv()]}  # was: [@TreeObsForRailEnv(), @GlobalObsForRailEnv()]
 TreeObsForRailEnv.max_depth = 2
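+# The local window is a square of side 2 * view_radius + 1, i.e. 11x11 cells here.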
+LocalObsForRailEnv.view_radius = 5
 
diff --git a/train_experiment.py b/train_experiment.py
index 0c1af1727abaaaf3078e24d7a250f071bda6cb9c..16c52b657db58c6f69e8cf91004d4259461be2cd 100644
--- a/train_experiment.py
+++ b/train_experiment.py
@@ -8,8 +8,10 @@ from flatland.envs.generators import complex_rail_generator
 # Import PPO trainer: we can replace these imports by any other trainer from RLLib.
 from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
 from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
-# from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
-from baselines.CustomPPOPolicyGraph import CustomPPOPolicyGraph as PolicyGraph
+# from baselines.CustomPPOTrainer import PPOTrainer as Trainer
+from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
+# from baselines.CustomPPOPolicyGraph import CustomPPOPolicyGraph as PolicyGraph
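+# The commented imports above toggle in the custom trainer / policy graph variants.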
 
 from ray.rllib.models import ModelCatalog
 from ray.tune.logger import pretty_print
@@ -27,9 +29,10 @@ import gin
 from ray import tune
 
 from ray.rllib.utils.seed import seed as set_seed
-from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv, LocalObsForRailEnv
 gin.external_configurable(TreeObsForRailEnv)
 gin.external_configurable(GlobalObsForRailEnv)
+gin.external_configurable(LocalObsForRailEnv)
 
 from ray.rllib.models.preprocessors import TupleFlatteningPreprocessor
 
@@ -76,6 +79,18 @@ def train(config, reporter):
             gym.spaces.Box(low=0, high=1, shape=(4,))))
         preprocessor = "global_obs_prep"
 
+    elif isinstance(config["obs_builder"], LocalObsForRailEnv):
+        view_radius = config["obs_builder"].view_radius
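+        # Each grid Box spans the square local window of side 2 * view_radius + 1
+        # around the agent; the trailing (4,) Box holds its one-hot direction.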
+        obs_space = gym.spaces.Tuple((
+            gym.spaces.Box(low=0, high=1, shape=(2 * view_radius + 1, 2 * view_radius + 1, 16)),
+            gym.spaces.Box(low=0, high=1, shape=(2 * view_radius + 1, 2 * view_radius + 1, 2)),
+            gym.spaces.Box(low=0, high=1, shape=(2 * view_radius + 1, 2 * view_radius + 1, 4)),
+            gym.spaces.Box(low=0, high=1, shape=(4,))))
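+        # Reuse the tuple-flattening preprocessor registered for the global obs.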
+        preprocessor = "global_obs_prep"
+
     else:
         raise ValueError("Undefined observation space")
 
@@ -107,8 +122,10 @@ def train(config, reporter):
     trainer_config["num_envs_per_worker"] = 1
     trainer_config["env_config"] = env_config
     trainer_config["batch_mode"] = "complete_episodes"
-    trainer_config['simple_optimizer'] = True
+    trainer_config['simple_optimizer'] = False  # use RLlib's default optimizer instead of the simple one
     trainer_config['postprocess_inputs'] = True
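+    # Quieten RLlib's console output to warnings and above.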
+    trainer_config['log_level'] = 'WARN'
 
     def logger_creator(conf):
         """Creates a Unified logger with a default logdir prefix