diff --git a/.gitignore b/.gitignore
index 29b232caa8b74274c91f95feeaf06deb537e1d4f..907432162f5211ed2ea4d08f05c5a81b28c8646c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
 *pycache*
+*ppo_policy*
diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py
index ca2a78b3678f3fe989612d47264285dba7adcd38..a5184f2be9a02ea4dba2ef1d6381487f69feb541 100644
--- a/RLLib_training/RailEnvRLLibWrapper.py
+++ b/RLLib_training/RailEnvRLLibWrapper.py
@@ -4,7 +4,6 @@ from flatland.envs.observations import TreeObsForRailEnv
 from ray.rllib.utils.seed import seed as set_seed
 from flatland.envs.generators import complex_rail_generator, random_rail_generator
 import numpy as np
-from flatland.envs.predictions import DummyPredictorForRailEnv
 
 
 class RailEnvRLLibWrapper(MultiAgentEnv):
@@ -34,7 +33,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         self.env = RailEnv(width=config["width"], height=config["height"],
                            number_of_agents=config["number_of_agents"],
                            obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator,
-                           prediction_builder_object=DummyPredictorForRailEnv())
+                           prediction_builder_object=config['predictor'])
 
         if self.predefined_env:
             self.env.load(config['load_env_path'])
@@ -42,6 +41,7 @@
 
         self.width = self.env.width
         self.height = self.env.height
+        self.step_memory = config["step_memory"]
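
Note on usage: RailEnvRLLibWrapper is configured entirely through the RLLib env_config dict, so the two keys introduced above ("predictor" and "step_memory") must be supplied alongside the existing ones (train_experiment.py below does exactly that). A minimal sketch of such a dict, assuming the flatland classes imported in this patch; the values are illustrative, not the experiment's actual settings:

    from flatland.envs.observations import TreeObsForRailEnv
    from flatland.envs.predictions import DummyPredictorForRailEnv

    # Keys mirror those read by RailEnvRLLibWrapper.__init__ above.
    env_config = {"width": 8,
                  "height": 8,
                  "number_of_agents": 3,
                  "obs_builder": TreeObsForRailEnv(max_depth=2),
                  "predictor": DummyPredictorForRailEnv(),  # None disables the collision features
                  "step_memory": 2}  # 1 = current obs only, 2 = stack current + previous
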
@@ -53,51 +53,60 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         obs = self.env.reset()
         predictions = self.env.predict()
-        pred_pos = np.concatenate([[x[:, 1:3]] for x in list(predictions.values())], axis=0)
+        if predictions != {}:
+            pred_pos = np.concatenate([[x[:, 1:3]] for x in list(predictions.values())], axis=0)
 
         o = dict()
 
         for i_agent in range(len(self.env.agents)):
+
+            if predictions != {}:
+                # Prediction of collisions that will be added to the observation.
+                # Allows the agent to know which other train it is about to meet (it may come
+                # up with a priority order of trains).
+                pred_obs = np.zeros((len(predictions[0]), len(self.env.agents)))
-            # prediction of collision that will be added to the observation
-            # Allows to the agent to know which other train is is about to meet (maybe will come
-            # up with a priority order of trains).
-            pred_obs = np.zeros((len(predictions[0]), len(self.env.agents)))
 
+                for time_offset in range(len(predictions[0])):
-            for time_offset in range(len(predictions[0])):
 
+                    # We consider a time window of t-1; t+1 to find a collision
+                    collision_window = list(range(max(time_offset - 1, 0), min(time_offset + 2, len(predictions[0]))))
-                # We consider a time window of t-1; t+1 to find a collision
-                collision_window = list(range(max(time_offset - 1, 0), min(time_offset + 2, len(predictions[0]))))
 
+                    coord_agent = pred_pos[i_agent, time_offset, 0] + 1000*pred_pos[i_agent, time_offset, 1]
-                coord_agent = pred_pos[i_agent, time_offset, 0] + 1000*pred_pos[i_agent, time_offset, 1]
 
+                    # x coordinates of all other trains in the time window
+                    x_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent+1, len(self.env.agents)))][
+                                           :, collision_window, 0]
-                # x coordinates of all other train in the time window
-                x_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent+1, len(self.env.agents)))][
-                                       :, collision_window, 0]
 
+                    # y coordinates of all other trains in the time window
+                    y_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent + 1, len(self.env.agents)))][
+                                           :, collision_window, 1]
-                # y coordinates of all other train in the time window
-                y_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent + 1, len(self.env.agents)))][
-                                       :, collision_window, 1]
 
+                    coord_other_agents = x_coord_other_agents + 1000*y_coord_other_agents
-                coord_other_agents = x_coord_other_agents + 1000*y_coord_other_agents
 
+                    # collision_info here contains the index of the agent colliding with the current agent
+                    for collision_info in np.argwhere(coord_agent == coord_other_agents)[:, 0]:
+                        pred_obs[time_offset, collision_info + 1*(collision_info >= i_agent)] = 1
-                # collision_info here contains the index of the agent colliding with the current agent
-                for collision_info in np.argwhere(coord_agent == coord_other_agents)[:, 0]:
-                    pred_obs[time_offset, collision_info + 1*(collision_info >= i_agent)] = 1
 
+                agent_id_one_hot = np.zeros(len(self.env.agents))
+                agent_id_one_hot[i_agent] = 1
+                o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
+            else:
+                o[i_agent] = obs[i_agent]
+
+        self.rail = self.env.rail
+        self.agents = self.env.agents
+        self.agents_static = self.env.agents_static
+        self.dev_obs_dict = self.env.dev_obs_dict
-            agent_id_one_hot = np.zeros(len(self.env.agents))
-            agent_id_one_hot[i_agent] = 1
-            o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
 
+        if self.step_memory < 2:
+            return o
+        else:
+            self.old_obs = o
+            oo = dict()
-        self.old_obs = o
-        oo = dict()
-        for i_agent in range(len(self.env.agents)):
-            oo[i_agent] = [o[i_agent], o[i_agent][0], o[i_agent][1], o[i_agent][2]]
-        self.rail = self.env.rail
-        self.agents = self.env.agents
-        self.agents_static = self.env.agents_static
-        self.dev_obs_dict = self.env.dev_obs_dict
-        return oo
+            for i_agent in range(len(self.env.agents)):
+                oo[i_agent] = [o[i_agent], o[i_agent]]
+            return oo
 
     def step(self, action_dict):
         obs, rewards, dones, infos = self.env.step(action_dict)
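
Because the interleaved -/+ lines above are hard to follow, here is the added collision logic restated as one standalone function. The function name and docstring are illustrative, not part of the codebase; pred_pos is the (n_agents, horizon, 2) array of predicted (x, y) positions built above, and the x + 1000*y hash assumes x coordinates stay below 1000:

    import numpy as np

    def collision_flags(pred_pos, i_agent):
        """pred_obs[t, j] == 1 iff agent j is predicted to occupy agent
        i_agent's cell at time t, within the window [t-1, t+1]."""
        n_agents, horizon, _ = pred_pos.shape
        others = [j for j in range(n_agents) if j != i_agent]
        pred_obs = np.zeros((horizon, n_agents))
        for t in range(horizon):
            window = list(range(max(t - 1, 0), min(t + 2, horizon)))
            # Hash (x, y) into a single integer; valid while x stays below 1000.
            own = pred_pos[i_agent, t, 0] + 1000 * pred_pos[i_agent, t, 1]
            other = pred_pos[others][:, window, 0] + 1000 * pred_pos[others][:, window, 1]
            for hit in np.argwhere(own == other)[:, 0]:
                pred_obs[t, hit + (hit >= i_agent)] = 1  # map back to the global agent index
        return pred_obs
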
@@ -109,39 +118,42 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
 
         # print(self.agents_done)
         # print(dones)
         predictions = self.env.predict()
-        pred_pos = np.concatenate([[x[:, 1:3]] for x in list(predictions.values())], axis=0)
+        if predictions != {}:
+            pred_pos = np.concatenate([[x[:, 1:3]] for x in list(predictions.values())], axis=0)
 
         for i_agent in range(len(self.env.agents)):
             if i_agent not in self.agents_done:
                 # prediction of collision that will be added to the observation
                 # Allows to the agent to know which other train is is about to meet (maybe will come
                 # up with a priority order of trains).
-                pred_obs = np.zeros((len(predictions[0]), len(self.env.agents)))
+                if predictions != {}:
+                    pred_obs = np.zeros((len(predictions[0]), len(self.env.agents)))
 
+                    for time_offset in range(len(predictions[0])):
-                for time_offset in range(len(predictions[0])):
 
+                        # We consider a time window of t-1; t+1 to find a collision
+                        collision_window = list(range(max(time_offset - 1, 0), min(time_offset + 2, len(predictions[0]))))
-                    # We consider a time window of t-1; t+1 to find a collision
-                    collision_window = list(range(max(time_offset - 1, 0), min(time_offset + 2, len(predictions[0]))))
 
+                        coord_agent = pred_pos[i_agent, time_offset, 0] + 1000*pred_pos[i_agent, time_offset, 1]
-                    coord_agent = pred_pos[i_agent, time_offset, 0] + 1000*pred_pos[i_agent, time_offset, 1]
 
+                        # x coordinates of all other trains in the time window
+                        x_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent+1, len(self.env.agents)))][
+                                               :, collision_window, 0]
-                    # x coordinates of all other train in the time window
-                    x_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent+1, len(self.env.agents)))][
-                                           :, collision_window, 0]
-
-                    # y coordinates of all other train in the time window
-                    y_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent + 1, len(self.env.agents)))][
-                                           :, collision_window, 1]
+                        # y coordinates of all other trains in the time window
+                        y_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent + 1, len(self.env.agents)))][
+                                               :, collision_window, 1]
 
-                    coord_other_agents = x_coord_other_agents + 1000*y_coord_other_agents
+                        coord_other_agents = x_coord_other_agents + 1000*y_coord_other_agents
 
-                    # collision_info here contains the index of the agent colliding with the current agent
-                    for collision_info in np.argwhere(coord_agent == coord_other_agents)[:, 0]:
-                        pred_obs[time_offset, collision_info + 1*(collision_info >= i_agent)] = 1
+                        # collision_info here contains the index of the agent colliding with the current agent
+                        for collision_info in np.argwhere(coord_agent == coord_other_agents)[:, 0]:
+                            pred_obs[time_offset, collision_info + 1*(collision_info >= i_agent)] = 1
 
-                agent_id_one_hot = np.zeros(len(self.env.agents))
-                agent_id_one_hot[i_agent] = 1
-                o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
+                    agent_id_one_hot = np.zeros(len(self.env.agents))
+                    agent_id_one_hot[i_agent] = 1
+                    o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
+                else:
+                    o[i_agent] = obs[i_agent]
 
                 r[i_agent] = rewards[i_agent]
                 d[i_agent] = dones[i_agent]
@@ -162,13 +174,16 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         self.agents_static = self.env.agents_static
         self.dev_obs_dict = self.env.dev_obs_dict
         #print('Old OBS #####', self.old_obs)
-        oo = dict()
-        for i_agent in range(len(self.env.agents)):
-            if i_agent not in self.agents_done:
-                oo[i_agent] = [o[i_agent], self.old_obs[i_agent][0], self.old_obs[i_agent][1],
-                               self.old_obs[i_agent][2]]
-        self.old_obs = o
+        if self.step_memory >= 2:
+            oo = dict()
+
+            for i_agent in range(len(self.env.agents)):
+                if i_agent not in self.agents_done:
+                    oo[i_agent] = [o[i_agent], self.old_obs[i_agent]]
+
+            self.old_obs = o
+
         for agent, done in dones.items():
             if done and agent != '__all__':
                 self.agents_done.append(agent)
@@ -183,7 +198,10 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         # o['global_obs'] = np.ones((17, 17)) * 17
         # r['global_obs'] = 0
         # d['global_obs'] = True
-        return oo, r, d, infos
+        if self.step_memory < 2:
+            return o, r, d, infos
+        else:
+            return oo, r, d, infos
 
     def get_agent_handles(self):
         return self.env.get_agent_handles()
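
With step_memory == 2, reset() and step() now return, for each agent, a pair [current observation, previous observation]; reset() duplicates the current observation since no history exists yet. The pairing rule as a compact sketch (the helper name is illustrative, not part of the codebase):

    def stack_with_memory(current, previous=None):
        """Pair each agent's current observation with the previous one.
        On reset there is no history, so the current observation is duplicated,
        matching oo[i_agent] = [o[i_agent], o[i_agent]] above."""
        previous = previous if previous is not None else current
        return {agent: [obs, previous[agent]] for agent, obs in current.items()}

This doubled structure is what the repeated observation space in train_experiment.py below accounts for.
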
diff --git a/RLLib_training/custom_preprocessors.py b/RLLib_training/custom_preprocessors.py
index 45e6937fabea10346b07d4923a2f6ed81a5280ac..1d1d214cef1af84b99c719359a859712559974bc 100644
--- a/RLLib_training/custom_preprocessors.py
+++ b/RLLib_training/custom_preprocessors.py
@@ -49,13 +49,15 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
 
 class CustomPreprocessor(Preprocessor):
     def _init_shape(self, obs_space, options):
+        return (sum([space.shape[0] for space in obs_space]), )
         return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0]*obs_space[2].shape[1])*2,)
 
     def transform(self, observation):
         # if len(observation) == 111:
-        return np.concatenate([norm_obs_clip(observation[0][0]), observation[0][1], observation[0][2].flatten(), norm_obs_clip(observation[1]), observation[2], observation[3].flatten()])
-        one_hot = observation[-3:]
-        return np.append(obs, one_hot)
+        return np.concatenate([norm_obs_clip(obs) for obs in observation])
+        #return np.concatenate([norm_obs_clip(observation[0][0]), observation[0][1], observation[0][2].flatten(), norm_obs_clip(observation[1]), observation[2], observation[3].flatten()])
+        #one_hot = observation[-3:]
+        #return np.append(obs, one_hot)
 
 # else:
 #     return observation
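
The rewritten transform() assumes every element of the observation tuple is a flat vector that norm_obs_clip can normalise, which matches the predictor-less, step_memory = 2 case: two 147-feature tree observations per agent, flattened to 294 values, consistent with the new early return in _init_shape (which also leaves the old shape computation below it unreachable). A self-contained sketch of that path, with a plain clip standing in for the norm_obs_clip helper defined at the top of this file:

    import numpy as np

    def transform_no_predictor(observation, norm_obs_clip=lambda o: np.clip(o, -1, 1)):
        # Normalise each remembered frame, then flatten into one vector.
        return np.concatenate([norm_obs_clip(obs) for obs in observation])

    frames = [np.random.rand(147), np.random.rand(147)]  # current + previous tree obs
    assert transform_no_predictor(frames).shape == (294,)
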
= {"grid_search": [@GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]} -run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv(), @GlobalObsForRailEnv()]} +run_experiment.obs_builder = @TreeObsForRailEnv() TreeObsForRailEnv.max_depth = 2 LocalObsForRailEnv.view_radius = 5 run_experiment.entropy_coeff = 0.01 run_experiment.kl_coeff = {"grid_search": [0, 0.2]} -run_experiment.lambda_gae = {"grid_search": [0.9, 1.0]} +run_experiment.lambda_gae = 0.9# {"grid_search": [0.9, 1.0]} diff --git a/RLLib_training/experiment_configs/experiment_agent_memory/config.gin b/RLLib_training/experiment_configs/experiment_agent_memory/config.gin new file mode 100644 index 0000000000000000000000000000000000000000..58df080ecd624d0076f06b81b785bb4ac29d3139 --- /dev/null +++ b/RLLib_training/experiment_configs/experiment_agent_memory/config.gin @@ -0,0 +1,27 @@ +run_experiment.name = "memory_experiment_results" +run_experiment.num_iterations = 2002 +run_experiment.save_every = 50 +run_experiment.hidden_sizes = {"grid_search": [[32, 32], [64, 64], [128, 128]]} + +run_experiment.map_width = 8 +run_experiment.map_height = 8 +run_experiment.n_agents = 3 +run_experiment.rail_generator = "complex_rail_generator" +run_experiment.nr_extra = 5 +run_experiment.policy_folder_name = "ppo_policy_hidden_size_{config[hidden_sizes][0]}_entropy_coeff_{config[entropy_coeff]}_" + +run_experiment.horizon = 50 +run_experiment.seed = 123 + +#run_experiment.conv_model = {"grid_search": [True, False]} +run_experiment.conv_model = False + +run_experiment.obs_builder = @TreeObsForRailEnv() +TreeObsForRailEnv.max_depth = 2 +LocalObsForRailEnv.view_radius = 5 + +run_experiment.entropy_coeff = {"grid_search": [1e-4, 1e-3, 1e-2]} +run_experiment.kl_coeff = 0.2 +run_experiment.lambda_gae = 0.9 +run_experiment.predictor = None#@DummyPredictorForRailEnv() +run_experiment.step_memory = 2 diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py index d7bf2614ab4fea523e1b68cc0bb3fafce7a3b3b6..4d3a2a92a33d888f7268f9686277b3f108149c7c 100644 --- a/RLLib_training/train_experiment.py +++ b/RLLib_training/train_experiment.py @@ -18,6 +18,9 @@ from baselines.RLLib_training.custom_preprocessors import CustomPreprocessor, Co from baselines.RLLib_training.custom_models import ConvModelGlobalObs +from flatland.envs.predictions import DummyPredictorForRailEnv +gin.external_configurable(DummyPredictorForRailEnv) + import ray import numpy as np @@ -57,16 +60,18 @@ def train(config, reporter): "nr_extra": config["nr_extra"], "number_of_agents": config['n_agents'], "seed": config['seed'], - "obs_builder": config['obs_builder']} + "obs_builder": config['obs_builder'], + "predictor": config["predictor"], + "step_memory": config["step_memory"]} # Observation space and action space definitions if isinstance(config["obs_builder"], TreeObsForRailEnv): - obs_space = gym.spaces.Tuple((gym.spaces.Box(low=0, high=float('inf'), shape=(147,)), - gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)), - gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])), - gym.spaces.Box(low=0, high=float('inf'), shape=(147,)), - gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)), - gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])))) + if config['predictor'] is None: + obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)), ) * config['step_memory']) + else: + obs_space = 
diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py
index d7bf2614ab4fea523e1b68cc0bb3fafce7a3b3b6..4d3a2a92a33d888f7268f9686277b3f108149c7c 100644
--- a/RLLib_training/train_experiment.py
+++ b/RLLib_training/train_experiment.py
@@ -18,6 +18,9 @@ from baselines.RLLib_training.custom_preprocessors import CustomPreprocessor, Co
 from baselines.RLLib_training.custom_models import ConvModelGlobalObs
 
+from flatland.envs.predictions import DummyPredictorForRailEnv
+gin.external_configurable(DummyPredictorForRailEnv)
+
 import ray
 import numpy as np
@@ -57,16 +60,18 @@ def train(config, reporter):
                   "nr_extra": config["nr_extra"],
                   "number_of_agents": config['n_agents'],
                   "seed": config['seed'],
-                  "obs_builder": config['obs_builder']}
+                  "obs_builder": config['obs_builder'],
+                  "predictor": config["predictor"],
+                  "step_memory": config["step_memory"]}
 
     # Observation space and action space definitions
     if isinstance(config["obs_builder"], TreeObsForRailEnv):
-        obs_space = gym.spaces.Tuple((gym.spaces.Box(low=0, high=float('inf'), shape=(147,)),
-                                      gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
-                                      gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),
-                                      gym.spaces.Box(low=0, high=float('inf'), shape=(147,)),
-                                      gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
-                                      gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents']))))
+        if config['predictor'] is None:
+            obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),) * config['step_memory'])
+        else:
+            obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),
+                                          gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
+                                          gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),) * config['step_memory'])
 
         preprocessor = "tree_obs_prep"
 
     elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
@@ -126,11 +131,11 @@ def train(config, reporter):
     trainer_config["horizon"] = config['horizon']
     trainer_config["num_workers"] = 0
-    trainer_config["num_cpus_per_worker"] = 11
-    trainer_config["num_gpus"] = 0.5
-    trainer_config["num_gpus_per_worker"] = 0.5
+    trainer_config["num_cpus_per_worker"] = 4
+    trainer_config["num_gpus"] = 0.2
+    trainer_config["num_gpus_per_worker"] = 0.2
     trainer_config["num_cpus_for_driver"] = 1
-    trainer_config["num_envs_per_worker"] = 6
+    trainer_config["num_envs_per_worker"] = 1
     trainer_config['entropy_coeff'] = config['entropy_coeff']
     trainer_config["env_config"] = env_config
     trainer_config["batch_mode"] = "complete_episodes"
@@ -170,7 +175,8 @@ def train(config, reporter):
 @gin.configurable
 def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
                    map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder,
-                   entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae):
+                   entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae,
+                   predictor, step_memory):
 
     tune.run(
         train,
@@ -191,11 +197,13 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
             "rail_generator": rail_generator,
             "nr_extra": nr_extra,
             "kl_coeff": kl_coeff,
-            "lambda_gae": lambda_gae
+            "lambda_gae": lambda_gae,
+            "predictor": predictor,
+            "step_memory": step_memory
         },
         resources_per_trial={
-            "cpu": 12,
-            "gpu": 0.5
+            "cpu": 5,
+            "gpu": 0.2
         },
         local_dir=local_dir
     )
@@ -203,6 +211,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
 if __name__ == '__main__':
     gin.external_configurable(tune.grid_search)
-    dir = '/home/guillaume/flatland/baselines/RLLib_training/experiment_configs/env_size_benchmark_3_agents'  # To Modify
+    dir = '/home/guillaume/flatland/baselines/RLLib_training/experiment_configs/experiment_agent_memory'  # To Modify
    gin.parse_config_file(dir + '/config.gin')
     run_experiment(local_dir=dir)
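
For reference, the two observation-space branches above expand as follows for n_agents = 3 and step_memory = 2. This is a sketch, not additional project code; the 147 is the tree-observation length and the 20 is the prediction horizon assumed by the Box shapes in the diff:

    import gym

    n_agents, step_memory = 3, 2

    # predictor is None: one unbounded 147-vector per remembered frame.
    plain = gym.spaces.Tuple(
        (gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),) * step_memory)

    # predictor set: (tree obs, agent one-hot, 20-step collision matrix) per frame.
    with_predictions = gym.spaces.Tuple(
        (gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),
         gym.spaces.Box(low=0, high=1, shape=(n_agents,)),
         gym.spaces.Box(low=0, high=1, shape=(20, n_agents))) * step_memory)

    assert len(plain.spaces) == 2 and len(with_predictions.spaces) == 6
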