diff --git a/.gitignore b/.gitignore
index 29b232caa8b74274c91f95feeaf06deb537e1d4f..907432162f5211ed2ea4d08f05c5a81b28c8646c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
 *pycache*
+*ppo_policy*
diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py
index c018683105fd502d5f6357d399f6679826c28f56..57fe38e481780ba52e5a179417b85ccef053a97b 100644
--- a/RLLib_training/RailEnvRLLibWrapper.py
+++ b/RLLib_training/RailEnvRLLibWrapper.py
@@ -1,93 +1,208 @@
 from flatland.envs.rail_env import RailEnv
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.generators import random_rail_generator
 from ray.rllib.utils.seed import seed as set_seed
+from flatland.envs.generators import complex_rail_generator, random_rail_generator
 import numpy as np
 
 
 class RailEnvRLLibWrapper(MultiAgentEnv):
 
     def __init__(self, config):
-                 # width,
-                 # height,
-                 # rail_generator=random_rail_generator(),
-                 # number_of_agents=1,
-                 # obs_builder_object=TreeObsForRailEnv(max_depth=2)):
+
         super(MultiAgentEnv, self).__init__()
         if hasattr(config, "vector_index"):
             vector_index = config.vector_index
         else:
             vector_index = 1
-        #self.rail_generator = config["rail_generator"](nr_start_goal=config['number_of_agents'], min_dist=5,
-         #                                              nr_extra=30, seed=config['seed'] * (1+vector_index))
+
+        self.predefined_env = False
+
+        if config['rail_generator'] == "complex_rail_generator":
+            self.rail_generator = complex_rail_generator(nr_start_goal=config['number_of_agents'], min_dist=5,
+                                                          nr_extra=config['nr_extra'], seed=config['seed'] * (1+vector_index))
+        elif config['rail_generator'] == "random_rail_generator":
+            self.rail_generator = random_rail_generator()
+        elif config['rail_generator'] == "load_env":
+            self.predefined_env = True
+
+        else:
+            raise(ValueError, f'Unknown rail generator: {config["rail_generator"]}')
+
         set_seed(config['seed'] * (1+vector_index))
-        #self.env = RailEnv(width=config["width"], height=config["height"],
-        self.env = RailEnv(width=10, height=20,
-                number_of_agents=config["number_of_agents"], obs_builder_object=config['obs_builder'])
+        self.env = RailEnv(width=config["width"], height=config["height"],
+                number_of_agents=config["number_of_agents"],
+                obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator,
+                prediction_builder_object=config['predictor'])
 
-        self.env.load_resource('torch_training.railway', 'complex_scene.pkl')
+        if self.predefined_env:
+            self.env.load(config['load_env_path'])
+            self.env.load_resource('torch_training.railway', config['load_env_path'])
 
         self.width = self.env.width
         self.height = self.env.height
+        self.step_memory = config["step_memory"]
 
-
-    
     def reset(self):
         self.agents_done = []
-        obs = self.env.reset(False, False)
+        if self.predefined_env:
+            obs = self.env.reset(False, False)
+        else:
+            obs = self.env.reset()
+
+        predictions = self.env.predict()
+        if predictions != {}:
+            # pred_pos is a 3 dimensions array (N_Agents, T_pred, 2) containing x and y coordinates of
+            # agents at each time step
+            pred_pos = np.concatenate([[x[:, 1:3]] for x in list(predictions.values())], axis=0)
+            pred_dir = [x[:, 2] for x in list(predictions.values())]
+
         o = dict()
-        # o['agents'] = obs
-        # obs[0] = [obs[0], np.ones((17, 17)) * 17]
-        # obs['global_obs'] = np.ones((17, 17)) * 17
 
+        for i_agent in range(len(self.env.agents)):
+
+            if predictions != {}:
+                pred_obs = self.get_prediction_as_observation(pred_pos, pred_dir, i_agent)
+
+                agent_id_one_hot = np.zeros(len(self.env.agents))
+                agent_id_one_hot[i_agent] = 1
+                o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
+            else:
+                o[i_agent] = obs[i_agent]
 
+        # needed for the renderer
         self.rail = self.env.rail
         self.agents = self.env.agents
         self.agents_static = self.env.agents_static
         self.dev_obs_dict = self.env.dev_obs_dict
-        return obs
+
+        if self.step_memory < 2:
+            return o
+        else:
+            self.old_obs = o
+            oo = dict()
+
+            for i_agent in range(len(self.env.agents)):
+                oo[i_agent] = [o[i_agent], o[i_agent]]
+            return oo
 
     def step(self, action_dict):
         obs, rewards, dones, infos = self.env.step(action_dict)
-        # print(obs)
 
         d = dict()
         r = dict()
         o = dict()
-        # print(self.agents_done)
-        # print(dones)
-        for agent, done in dones.items():
-            if agent not in self.agents_done:
-                if agent != '__all__':
-                    o[agent] = obs[agent]
-                    r[agent] = rewards[agent]
-    
-                d[agent] = dones[agent]
+
+        predictions = self.env.predict()
+        if predictions != {}:
+            # pred_pos is a 3 dimensions array (N_Agents, T_pred, 2) containing x and y coordinates of
+            # agents at each time step
+            pred_pos = np.concatenate([[x[:, 1:3]] for x in list(predictions.values())], axis=0)
+            pred_dir = [x[:, 2] for x in list(predictions.values())]
+
+        for i_agent in range(len(self.env.agents)):
+            if i_agent not in self.agents_done:
+
+                if predictions != {}:
+                    pred_obs = self.get_prediction_as_observation(pred_pos, pred_dir, i_agent)
+                    agent_id_one_hot = np.zeros(len(self.env.agents))
+                    agent_id_one_hot[i_agent] = 1
+                    o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
+                else:
+                    o[i_agent] = obs[i_agent]
+                r[i_agent] = rewards[i_agent]
+                d[i_agent] = dones[i_agent]
+
+        d['__all__'] = dones['__all__']
+
+        if self.step_memory >= 2:
+            oo = dict()
+
+            for i_agent in range(len(self.env.agents)):
+                if i_agent not in self.agents_done:
+                    oo[i_agent] = [o[i_agent], self.old_obs[i_agent]]
+
+            self.old_obs = o
 
         for agent, done in dones.items():
             if done and agent != '__all__':
                 self.agents_done.append(agent)
 
-        self.rail = self.env.rail
-        self.agents = self.env.agents
-        self.agents_static = self.env.agents_static
-        self.dev_obs_dict = self.env.dev_obs_dict
-        
-        #print(obs)
-        #return obs, rewards, dones, infos
-        # oo = dict()
-        # oo['agents'] = o
-        # o['global'] = np.ones((17, 17)) * 17
-
-        # o[0] = [o[0], np.ones((17, 17)) * 17]
-        # o['global_obs'] = np.ones((17, 17)) * 17
-        # r['global_obs'] = 0
-        # d['global_obs'] = True
-        return o, r, d, infos
+        if self.step_memory < 2:
+            return o, r, d, infos
+        else:
+            return oo, r, d, infos
 
     def get_agent_handles(self):
         return self.env.get_agent_handles()
 
     def get_num_agents(self):
         return self.env.get_num_agents()
+
+    def get_prediction_as_observation(self, pred_pos, pred_dir, agent_handle):
+        '''
+        :param pred_pos: pred_pos should be a 3 dimensions array (N_Agents, T_pred, 2) containing x and y
+                         predicted coordinates of agents at each time step
+        :param pred_dir: pred_dir should be a 2 dimensions array (N_Agents, T_pred) predicted directions
+                         of agents at each time step
+        :param agent_handle: agent index
+        :return: 2 dimensional array (T_pred, N_agents) with value 1 at coord. (t,i) if agent 'agent_handle'
+                and agent i are going to meet at time step t.
+
+        Computes prediction of collision that will be added to the observation.
+        Allows to the agent to know which other train it is about to meet, and when.
+        The id of the other trains are shared, allowing eventually the agents to come
+        up with a priority order of trains.
+        '''
+
+        pred_obs = np.zeros((len(pred_pos[1]), len(self.env.agents)))
+
+        for time_offset in range(len(pred_pos[1])):
+
+            # We consider a time window of t-1:t+1 to find a collision
+            collision_window = list(range(max(time_offset - 1, 0), min(time_offset + 2, len(pred_pos[1]))))
+
+            # coordinate of agent `agent_handle` at time t.
+            coord_agent = pred_pos[agent_handle, time_offset, 0] + 1000 * pred_pos[agent_handle, time_offset, 1]
+
+            # x coordinates of all other agents in the time window
+            # array of dim (N_Agents, 3), the 3 elements corresponding to x coordinates of the agents
+            # at t-1, t, t + 1
+            x_coord_other_agents = pred_pos[list(range(agent_handle)) +
+                                            list(range(agent_handle + 1,
+                                                       len(self.env.agents)))][:, collision_window, 0]
+
+            # y coordinates of all other agents in the time window
+            # array of dim (N_Agents, 3), the 3 elements corresponding to y coordinates of the agents
+            # at t-1, t, t + 1
+            y_coord_other_agents = pred_pos[list(range(agent_handle)) +
+                                            list(range(agent_handle + 1, len(self.env.agents)))][
+                                   :, collision_window, 1]
+
+            coord_other_agents = x_coord_other_agents + 1000 * y_coord_other_agents
+
+            # collision_info here contains the index of the agent colliding with the current agent and
+            # the delta_t at which they visit the same cell (0 for t-1, 1 for t or 2 for t+1)
+            for collision_info in np.argwhere(coord_agent == coord_other_agents):
+                # If they are on the same cell at the same time, there is a collison in all cases
+                if collision_info[1] == 1:
+                    pred_obs[time_offset, collision_info[0] + 1 * (collision_info[0] >= agent_handle)] = 1
+                elif collision_info[1] == 0:
+                    # In this case, the other agent (agent 2) was on the same cell at t-1
+                    # There is a collision if agent 2 is at t, on the cell where was agent 1 at t-1
+                    coord_agent_1_t_minus_1 = pred_pos[agent_handle, time_offset-1, 0] + \
+                                          1000 * pred_pos[agent_handle, time_offset, 1]
+                    coord_agent_2_t = coord_other_agents[collision_info[0], 1]
+                    if coord_agent_1_t_minus_1 == coord_agent_2_t:
+                        pred_obs[time_offset, collision_info[0] + 1 * (collision_info[0] >= agent_handle)] = 1
+
+                elif collision_info[1] == 2:
+                    # In this case, the other agent (agent 2) will be on the same cell at t+1
+                    # There is a collision if agent 2 is at t, on the cell where will be agent 1 at t+1
+                    coord_agent_1_t_plus_1 = pred_pos[agent_handle, time_offset + 1, 0] + \
+                                              1000 * pred_pos[agent_handle, time_offset, 1]
+                    coord_agent_2_t = coord_other_agents[collision_info[0], 1]
+                    if coord_agent_1_t_plus_1 == coord_agent_2_t:
+                        pred_obs[time_offset, collision_info[0] + 1 * (collision_info[0] >= agent_handle)] = 1
+
+        return pred_obs
diff --git a/RLLib_training/custom_preprocessors.py b/RLLib_training/custom_preprocessors.py
index 1c3fa0898582a6f9d093dbcac787d70805b2e0b6..1d1d214cef1af84b99c719359a859712559974bc 100644
--- a/RLLib_training/custom_preprocessors.py
+++ b/RLLib_training/custom_preprocessors.py
@@ -1,7 +1,6 @@
 import numpy as np
 from ray.rllib.models.preprocessors import Preprocessor
 
-
 def max_lt(seq, val):
     """
     Return greatest item in seq for which item < val applies.
@@ -36,27 +35,31 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
     :param obs: Observation that should be normalized
     :param clip_min: min value where observation will be clipped
     :param clip_max: max value where observation will be clipped
-    :return: returns normalized and clipped observation
+    :return: returnes normalized and clipped observatoin
     """
     max_obs = max(1, max_lt(obs, 1000))
     min_obs = max(0, min_lt(obs, 0))
     if max_obs == min_obs:
-        return np.clip(np.array(obs)/ max_obs, clip_min, clip_max)
+        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
     norm = np.abs(max_obs - min_obs)
     if norm == 0:
         norm = 1.
-    return np.clip((np.array(obs)-min_obs)/ norm, clip_min, clip_max)
+    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
 
 
 class CustomPreprocessor(Preprocessor):
     def _init_shape(self, obs_space, options):
-        return (111,)
+        return (sum([space.shape[0] for space in obs_space]), )
+        return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0]*obs_space[2].shape[1])*2,)
 
     def transform(self, observation):
-        if len(observation) == 111:
-            return norm_obs_clip(observation)
-        else:
-            return observation
+        # if len(observation) == 111:
+        return np.concatenate([norm_obs_clip(obs) for obs in observation])
+        #return np.concatenate([norm_obs_clip(observation[0][0]), observation[0][1], observation[0][2].flatten(), norm_obs_clip(observation[1]), observation[2], observation[3].flatten()])
+        #one_hot = observation[-3:]
+        #return np.append(obs, one_hot)
+        # else:
+        #     return observation
 
 
 class ConvModelPreprocessor(Preprocessor):
diff --git a/RLLib_training/experiment_configs/env_complexity_benchmark/config.gin b/RLLib_training/experiment_configs/env_complexity_benchmark/config.gin
new file mode 100644
index 0000000000000000000000000000000000000000..82688fe897bc1160e8a03c73b176535599f8df9e
--- /dev/null
+++ b/RLLib_training/experiment_configs/env_complexity_benchmark/config.gin
@@ -0,0 +1,25 @@
+run_experiment.name = "observation_benchmark_results"
+run_experiment.num_iterations = 1002
+run_experiment.save_every = 50
+run_experiment.hidden_sizes = [32, 32]
+
+run_experiment.map_width = 15
+run_experiment.map_height = 15
+run_experiment.n_agents = 8
+run_experiment.rail_generator = "complex_rail_generator"
+run_experiment.nr_extra = 10#{"grid_search": [0, 5, 10, 20, 30, 40, 50, 60]}
+run_experiment.policy_folder_name = "ppo_policy_nr_extra_{config[nr_extra]}_"
+
+run_experiment.horizon = 50
+run_experiment.seed = 123
+
+#run_experiment.conv_model = {"grid_search": [True, False]}
+run_experiment.conv_model = False
+
+#run_experiment.obs_builder = {"grid_search": [@GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
+run_experiment.obs_builder = @TreeObsForRailEnv()
+TreeObsForRailEnv.max_depth = 2
+LocalObsForRailEnv.view_radius = 5
+
+run_experiment.entropy_coeff = 0.01
+
diff --git a/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin b/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin
new file mode 100644
index 0000000000000000000000000000000000000000..258bc1d97ee321b6b61d17d2a010e7310a9351ca
--- /dev/null
+++ b/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin
@@ -0,0 +1,26 @@
+run_experiment.name = "observation_benchmark_results"
+run_experiment.num_iterations = 2002
+run_experiment.save_every = 50
+run_experiment.hidden_sizes = [32, 32]
+
+run_experiment.map_width = 8
+run_experiment.map_height = 8
+run_experiment.n_agents = 3
+run_experiment.rail_generator = "complex_rail_generator"
+run_experiment.nr_extra = 5#{"grid_search": [0, 5, 10, 20, 30, 40, 50, 60]}
+run_experiment.policy_folder_name = "ppo_policy_two_obs_with_predictions_kl_coeff_{config[kl_coeff]}_horizon_{config[horizon]}_"
+
+run_experiment.horizon = {"grid_search": [50, 100]}
+run_experiment.seed = 123
+
+#run_experiment.conv_model = {"grid_search": [True, False]}
+run_experiment.conv_model = False
+
+#run_experiment.obs_builder = {"grid_search": [@GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
+run_experiment.obs_builder = @TreeObsForRailEnv()
+TreeObsForRailEnv.max_depth = 2
+LocalObsForRailEnv.view_radius = 5
+
+run_experiment.entropy_coeff = 0.01
+run_experiment.kl_coeff = {"grid_search": [0, 0.2]}
+run_experiment.lambda_gae = 0.9# {"grid_search": [0.9, 1.0]}
diff --git a/RLLib_training/experiment_configs/experiment_agent_memory/__init__.py b/RLLib_training/experiment_configs/experiment_agent_memory/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/RLLib_training/experiment_configs/experiment_agent_memory/config.gin b/RLLib_training/experiment_configs/experiment_agent_memory/config.gin
new file mode 100644
index 0000000000000000000000000000000000000000..58df080ecd624d0076f06b81b785bb4ac29d3139
--- /dev/null
+++ b/RLLib_training/experiment_configs/experiment_agent_memory/config.gin
@@ -0,0 +1,27 @@
+run_experiment.name = "memory_experiment_results"
+run_experiment.num_iterations = 2002
+run_experiment.save_every = 50
+run_experiment.hidden_sizes = {"grid_search": [[32, 32], [64, 64], [128, 128]]}
+
+run_experiment.map_width = 8
+run_experiment.map_height = 8
+run_experiment.n_agents = 3
+run_experiment.rail_generator = "complex_rail_generator"
+run_experiment.nr_extra = 5
+run_experiment.policy_folder_name = "ppo_policy_hidden_size_{config[hidden_sizes][0]}_entropy_coeff_{config[entropy_coeff]}_"
+
+run_experiment.horizon = 50
+run_experiment.seed = 123
+
+#run_experiment.conv_model = {"grid_search": [True, False]}
+run_experiment.conv_model = False
+
+run_experiment.obs_builder = @TreeObsForRailEnv()
+TreeObsForRailEnv.max_depth = 2
+LocalObsForRailEnv.view_radius = 5
+
+run_experiment.entropy_coeff = {"grid_search": [1e-4, 1e-3, 1e-2]}
+run_experiment.kl_coeff = 0.2
+run_experiment.lambda_gae = 0.9
+run_experiment.predictor = None#@DummyPredictorForRailEnv()
+run_experiment.step_memory = 2
diff --git a/RLLib_training/experiment_configs/observation_benchmark_loaded_env/config.gin b/RLLib_training/experiment_configs/observation_benchmark_loaded_env/config.gin
index 64ff1c981dc9d068dee3a089bc8cb77c834d9e63..1369bb44d9e6f6d761a2d7f4a37af11a735c1fab 100644
--- a/RLLib_training/experiment_configs/observation_benchmark_loaded_env/config.gin
+++ b/RLLib_training/experiment_configs/observation_benchmark_loaded_env/config.gin
@@ -4,8 +4,8 @@ run_experiment.save_every = 50
 run_experiment.hidden_sizes = [32, 32]
 
 run_experiment.map_width = 20
-run_experiment.map_height = 20
-run_experiment.n_agents = 5
+run_experiment.map_height = 10
+run_experiment.n_agents = 8
 run_experiment.policy_folder_name = "ppo_policy_{config[obs_builder].__class__.__name__}"#_entropy_coeff_{config[entropy_coeff]}_{config[hidden_sizes][0]}_hidden_sizes_"
 
 run_experiment.horizon = 50
diff --git a/RLLib_training/render_training_result.py b/RLLib_training/render_training_result.py
index 5b9a08cfbfb57cf9b8b1486a3bf91ef575c8ce64..5f0159fa70e975b80a17a296a82139470a645f48 100644
--- a/RLLib_training/render_training_result.py
+++ b/RLLib_training/render_training_result.py
@@ -50,11 +50,11 @@ ModelCatalog.register_custom_model("conv_model", ConvModelGlobalObs)
 ray.init()#object_store_memory=150000000000, redis_max_memory=30000000000)
 
 
-CHECKPOINT_PATH = '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/RLLib_training/experiment_configs/' \
-                  'conv_model_test/ppo_policy_TreeObsForRailEnv_5_agents_conv_model_False_ial1g3w9/checkpoint_51/checkpoint-51'
+CHECKPOINT_PATH = '/home/guillaume/Desktop/distMAgent/env_complexity_benchmark/' \
+                  'ppo_policy_nr_extra_10_0qxx0qy_/checkpoint_1001/checkpoint-1001'
 
-N_EPISODES = 3
-N_STEPS_PER_EPISODE = 50
+N_EPISODES = 10
+N_STEPS_PER_EPISODE = 80
 
 
 def render_training_result(config):
@@ -62,22 +62,11 @@ def render_training_result(config):
 
     set_seed(config['seed'], config['seed'], config['seed'])
 
-    transition_probability = [15,  # empty cell - Case 0
-                              5,  # Case 1 - straight
-                              5,  # Case 2 - simple switch
-                              1,  # Case 3 - diamond crossing
-                              1,  # Case 4 - single slip
-                              1,  # Case 5 - double slip
-                              1,  # Case 6 - symmetrical
-                              0,  # Case 7 - dead end
-                              1,  # Case 1b (8)  - simple turn right
-                              1,  # Case 1c (9)  - simple turn left
-                              1]  # Case 2b (10) - simple switch mirrored
-
     # Example configuration to generate a random rail
     env_config = {"width": config['map_width'],
                   "height": config['map_height'],
-                  "rail_generator": complex_rail_generator,
+                  "rail_generator": config["rail_generator"],
+                  "nr_extra": config["nr_extra"],
                   "number_of_agents": config['n_agents'],
                   "seed": config['seed'],
                   "obs_builder": config['obs_builder']}
@@ -85,7 +74,7 @@ def render_training_result(config):
 
     # Observation space and action space definitions
     if isinstance(config["obs_builder"], TreeObsForRailEnv):
-        obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
+        obs_space = gym.spaces.Box(low=-1, high=1, shape=(147,))
         preprocessor = "tree_obs_prep"
 
     elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
@@ -154,6 +143,8 @@ def render_training_result(config):
     trainer_config['simple_optimizer'] = False
     trainer_config['postprocess_inputs'] = True
     trainer_config['log_level'] = 'WARN'
+    trainer_config['num_sgd_iter'] = 10
+    trainer_config['clip_param'] = 0.2
 
     env = RailEnvRLLibWrapper(env_config)
 
@@ -163,21 +154,29 @@ def render_training_result(config):
 
     policy = trainer.get_policy(config['policy_folder_name'].format(**locals()))
 
-    env_renderer = RenderTool(env, gl="PIL", show=True)
+
+    preprocessor = CustomPreprocessor(gym.spaces.Box(low=-1, high=1, shape=(147,)))
+    env_renderer = RenderTool(env, gl="PIL")
     for episode in range(N_EPISODES):
         observation = env.reset()
         for i in range(N_STEPS_PER_EPISODE):
-
-            action, _, infos = policy.compute_actions(list(observation.values()), [])
-            env_renderer.renderEnv(show=True, frames=True, iEpisode=episode, iStep=i,
-                                   action_dict=action)
+            preprocessed_obs = []
+            for obs in observation.values():
+                preprocessed_obs.append(preprocessor.transform(obs))
+            action, _, infos = policy.compute_actions(preprocessed_obs, [])
             logits = infos['behaviour_logits']
             actions = dict()
             for j, logit in enumerate(logits):
                 actions[j] = np.argmax(logit)
-
+            # for j, act in enumerate(action):
+                # actions[j] = act
             time.sleep(1)
-            observation, _, _, _ = env.step(action)
+            print(actions, logits)
+            # print(action, print(infos['behaviour_logits']))
+            env_renderer.renderEnv(show=True, frames=True, iEpisode=episode, iStep=i,
+                                   action_dict=list(actions.values()))
+
+            observation, _, _, _ = env.step(actions)
 
     env_renderer.close_window()
 
@@ -185,7 +184,7 @@ def render_training_result(config):
 @gin.configurable
 def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
                    map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder,
-                   entropy_coeff, seed, conv_model):
+                   entropy_coeff, seed, conv_model, rail_generator, nr_extra):
 
     render_training_result(
         config={"n_agents": n_agents,
@@ -199,12 +198,15 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
                 "obs_builder": obs_builder,
                 "entropy_coeff": entropy_coeff,
                 "seed": seed,
-                "conv_model": conv_model
-                })
+                "conv_model": conv_model,
+                "rail_generator": rail_generator,
+                "nr_extra": 10# nr_extra
+                }
+    )
 
 
 if __name__ == '__main__':
     gin.external_configurable(tune.grid_search)
-    dir = '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/RLLib_training/experiment_configs/conv_model_test'  # To Modify
+    dir = '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/RLLib_training/experiment_configs/env_complexity_benchmark'  # To Modify
     gin.parse_config_file(dir + '/config.gin')
     run_experiment(local_dir=dir)
diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py
index 26c8c5bad70d07b83af407bf22c94910debdd2c3..d7d0b4a71f1abbb62134805b9e7bb41580c1f84e 100644
--- a/RLLib_training/train_experiment.py
+++ b/RLLib_training/train_experiment.py
@@ -3,6 +3,11 @@ import tempfile
 
 import gin
 import gym
+
+import gin
+
+from flatland.envs.generators import complex_rail_generator
+
 import ray
 from importlib_resources import path
 from ray import tune
@@ -12,6 +17,18 @@ from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
 from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
 from ray.rllib.models import ModelCatalog
 from ray.rllib.utils.seed import seed as set_seed
+from ray.tune.logger import pretty_print
+from baselines.RLLib_training.custom_preprocessors import CustomPreprocessor, ConvModelPreprocessor
+
+from baselines.RLLib_training.custom_models import ConvModelGlobalObs
+
+from flatland.envs.predictions import DummyPredictorForRailEnv
+gin.external_configurable(DummyPredictorForRailEnv)
+
+
+import ray
+import numpy as np
+
 from ray.tune.logger import UnifiedLogger
 from ray.tune.logger import pretty_print
 
@@ -21,6 +38,13 @@ from custom_preprocessors import CustomPreprocessor, ConvModelPreprocessor
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv, \
     LocalObsForRailEnv, GlobalObsForRailEnvDirectionDependent
+import tempfile
+
+from ray import tune
+
+from ray.rllib.utils.seed import seed as set_seed
+from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv,\
+                                       LocalObsForRailEnv, GlobalObsForRailEnvDirectionDependent
 
 gin.external_configurable(TreeObsForRailEnv)
 gin.external_configurable(GlobalObsForRailEnv)
@@ -43,21 +67,25 @@ def train(config, reporter):
 
     set_seed(config['seed'], config['seed'], config['seed'])
 
-    config['map_width'] = 20
-    config['map_height'] = 10
-    config['n_agents'] = 8
-
     # Example configuration to generate a random rail
     env_config = {"width": config['map_width'],
                   "height": config['map_height'],
-                  "rail_generator": complex_rail_generator,
+                  "rail_generator": config["rail_generator"],
+                  "nr_extra": config["nr_extra"],
                   "number_of_agents": config['n_agents'],
                   "seed": config['seed'],
-                  "obs_builder": config['obs_builder']}
+                  "obs_builder": config['obs_builder'],
+                  "predictor": config["predictor"],
+                  "step_memory": config["step_memory"]}
 
     # Observation space and action space definitions
     if isinstance(config["obs_builder"], TreeObsForRailEnv):
-        obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(111,))
+        if config['predictor'] is None:
+            obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)), ) * config['step_memory'])
+        else:
+            obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),
+                                        gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
+                                        gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),) *config['step_memory'])
         preprocessor = "tree_obs_prep"
 
     elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
@@ -92,7 +120,8 @@ def train(config, reporter):
     else:
         raise ValueError("Undefined observation space")
 
-    act_space = gym.spaces.Discrete(4)
+
+    act_space = gym.spaces.Discrete(5)
 
     # Dict with the different policies to train
     policy_graphs = {
@@ -115,9 +144,9 @@ def train(config, reporter):
     trainer_config["horizon"] = config['horizon']
 
     trainer_config["num_workers"] = 0
-    trainer_config["num_cpus_per_worker"] = 3
-    trainer_config["num_gpus"] = 0
-    trainer_config["num_gpus_per_worker"] = 0
+    trainer_config["num_cpus_per_worker"] = 4
+    trainer_config["num_gpus"] = 0.2
+    trainer_config["num_gpus_per_worker"] = 0.2
     trainer_config["num_cpus_for_driver"] = 1
     trainer_config["num_envs_per_worker"] = 1
     trainer_config['entropy_coeff'] = config['entropy_coeff']
@@ -126,6 +155,10 @@ def train(config, reporter):
     trainer_config['simple_optimizer'] = False
     trainer_config['postprocess_inputs'] = True
     trainer_config['log_level'] = 'WARN'
+    trainer_config['num_sgd_iter'] = 10
+    trainer_config['clip_param'] = 0.2
+    trainer_config['kl_coeff'] = config['kl_coeff']
+    trainer_config['lambda'] = config['lambda_gae']
 
     def logger_creator(conf):
         """Creates a Unified logger with a default logdir prefix
@@ -155,7 +188,9 @@ def train(config, reporter):
 @gin.configurable
 def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
                    map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder,
-                   entropy_coeff, seed, conv_model):
+                   entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae,
+                   predictor, step_memory):
+
     tune.run(
         train,
         name=name,
@@ -171,11 +206,17 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
                 "obs_builder": obs_builder,
                 "entropy_coeff": entropy_coeff,
                 "seed": seed,
-                "conv_model": conv_model
+                "conv_model": conv_model,
+                "rail_generator": rail_generator,
+                "nr_extra": nr_extra,
+                "kl_coeff": kl_coeff,
+                "lambda_gae": lambda_gae,
+                "predictor": predictor,
+                "step_memory": step_memory
                 },
         resources_per_trial={
-            "cpu": 2,
-            "gpu": 0.0
+            "cpu": 5,
+            "gpu": 0.2
         },
         verbose=2,
         local_dir=local_dir
@@ -184,8 +225,8 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
 
 if __name__ == '__main__':
     gin.external_configurable(tune.grid_search)
-    with path('RLLib_training.experiment_configs.observation_benchmark_loaded_env', 'config.gin') as f:
+    with path('RLLib_training.experiment_configs.experiment_agent_memory', 'config.gin') as f:
         gin.parse_config_file(f)
 
-    dir = os.path.join(__file_dirname__, 'experiment_configs', 'observation_benchmark_loaded_env')
+    dir = os.path.join(__file_dirname__, 'experiment_configs', 'experiment_agent_memory')
     run_experiment(local_dir=dir)
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 252ff7ce97b6fe9276c8629f17222f8274fff0a1..e673941b462550c94831646c40b39c692972ca45 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -5,6 +5,7 @@ from collections import deque
 import numpy as np
 import torch
 
+from dueling_double_dqn import Agent
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
@@ -41,17 +42,17 @@ env = RailEnv(width=15,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
               number_of_agents=1)
 
-"""
+
 
 env = RailEnv(width=10,
               height=20)
 env.load_resource('torch_training.railway', "complex_scene.pkl")
-
-env = RailEnv(width=15,
-              height=15,
-              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
+"""
+env = RailEnv(width=8,
+              height=8,
+              rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=5, max_dist=99999, seed=0),
               number_of_agents=1)
-env.reset(False, False)
+env.reset(True, True)
 
 env_renderer = RenderTool(env, gl="PILSVG")
 handle = env.get_agent_handles()
@@ -73,11 +74,10 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-# agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
+#agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint1500.pth'))
 
 demo = False
 
-
 def max_lt(seq, val):
     """
     Return greatest item in seq for which item < val applies.
@@ -138,7 +138,7 @@ for trials in range(1, n_trials + 1):
                                                                 current_depth=0)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
-
+        agent_data = np.clip(agent_data, -1, 1)
         obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
     for i in range(2):
         time_obs.append(obs)
@@ -149,7 +149,7 @@ for trials in range(1, n_trials + 1):
     score = 0
     env_done = 0
     # Run episode
-    for step in range(360):
+    for step in range(100):
         if demo:
             env_renderer.renderEnv(show=True, show_observations=False)
         # print(step)
@@ -160,14 +160,15 @@ for trials in range(1, n_trials + 1):
             action = agent.act(agent_obs[a], eps=eps)
             action_prob[action] += 1
             action_dict.update({a: action})
-
         # Environment step
+
         next_obs, all_rewards, done, _ = env.step(action_dict)
         for a in range(env.get_num_agents()):
             data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=7,
                                                                     current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
+            agent_data = np.clip(agent_data, -1, 1)
             next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
 
         time_obs.append(next_obs)
@@ -175,7 +176,6 @@ for trials in range(1, n_trials + 1):
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
             agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
-
             if done[a]:
                 final_obs[a] = agent_obs[a].copy()
                 final_obs_next[a] = agent_next_obs[a].copy()
@@ -217,4 +217,4 @@ for trials in range(1, n_trials + 1):
                 action_prob / np.sum(action_prob)))
         torch.save(agent.qnetwork_local.state_dict(),
                    os.path.join(__file_dirname__, 'Nets', 'avoid_checkpoint' + str(trials) + '.pth'))
-        action_prob = [1] * 4
+        action_prob = [1] * action_size