From 1ed4e72257f9a051916be9ece27e86c789460eba Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume.mollard2@gmail.com>
Date: Fri, 28 Jun 2019 18:09:16 +0200
Subject: [PATCH] last training set up before freeze
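
Clean up the RLLib baselines for the final training runs before the freeze:

* RailEnvRLLibWrapper: remove the commented-out prediction code and the
  unused get_prediction_as_observation helper; expose rail, agents,
  agents_static and dev_obs_dict so the renderer can be used on the
  wrapped environment.
* custom_preprocessors: rename CustomPreprocessor to TreeObsPreprocessor
  and drop ConvModelPreprocessor together with the remaining dead code.
* Delete the obsolete experiment_configs/*/config.gin files.
* train_experiment.py / render_training_result.py: keep only the
  TreeObsForRailEnv observation path and load the gin configuration from
  experiment_configs/config_example relative to the script directory;
  train_experiment.py additionally uses a single shared "ppo_policy" for
  all agents and reduces the per-trial CPU requirements.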

---
 RLLib_training/RailEnvRLLibWrapper.py         | 115 ++--------------
 RLLib_training/custom_preprocessors.py        |  33 +----
 .../conv_model_test/config.gin                |  23 ----
 .../entropy_coeff_benchmark/config.gin        |  19 ---
 .../env_complexity_benchmark/config.gin       |  25 ----
 .../env_size_benchmark_3_agents/config.gin    |  30 ----
 .../experiment_agent_memory/config.gin        |  27 ----
 .../n_agents_experiment/config.gin            |  19 ---
 .../observation_benchmark/config.gin          |  17 ---
 .../config.gin                                |  19 ---
 .../predictions_test/config.gin               |  28 ----
 RLLib_training/render_training_result.py      | 130 +++++-------------
 RLLib_training/train_experiment.py            |  88 ++++--------
 13 files changed, 75 insertions(+), 498 deletions(-)
 delete mode 100644 RLLib_training/experiment_configs/conv_model_test/config.gin
 delete mode 100644 RLLib_training/experiment_configs/entropy_coeff_benchmark/config.gin
 delete mode 100644 RLLib_training/experiment_configs/env_complexity_benchmark/config.gin
 delete mode 100644 RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin
 delete mode 100644 RLLib_training/experiment_configs/experiment_agent_memory/config.gin
 delete mode 100644 RLLib_training/experiment_configs/n_agents_experiment/config.gin
 delete mode 100644 RLLib_training/experiment_configs/observation_benchmark/config.gin
 delete mode 100644 RLLib_training/experiment_configs/observation_benchmark_loaded_env/config.gin
 delete mode 100644 RLLib_training/experiment_configs/predictions_test/config.gin

diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py
index 0fd0f59..d065063 100644
--- a/RLLib_training/RailEnvRLLibWrapper.py
+++ b/RLLib_training/RailEnvRLLibWrapper.py
@@ -5,13 +5,13 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.utils.seed import seed as set_seed
 
 
-
-
 class RailEnvRLLibWrapper(MultiAgentEnv):
 
     def __init__(self, config):
 
         super(MultiAgentEnv, self).__init__()
+
+        # Index of this environment instance when num_envs_per_worker > 1
         if hasattr(config, "vector_index"):
             vector_index = config.vector_index
         else:
@@ -24,12 +24,12 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
                                                          min_dist=config['min_dist'],
                                                          nr_extra=config['nr_extra'],
                                                          seed=config['seed'] * (1 + vector_index))
+
         elif config['rail_generator'] == "random_rail_generator":
             self.rail_generator = random_rail_generator()
         elif config['rail_generator'] == "load_env":
             self.predefined_env = True
             self.rail_generator = random_rail_generator()
-
         else:
             raise (ValueError, f'Unknown rail generator: {config["rail_generator"]}')
 
@@ -39,13 +39,18 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
                            obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator)
 
         if self.predefined_env:
-            # self.env.load(config['load_env_path'])
             self.env.load_resource('torch_training.railway', 'complex_scene.pkl')
 
         self.width = self.env.width
         self.height = self.env.height
         self.step_memory = config["step_memory"]
 
+        # needed for the renderer
+        self.rail = self.env.rail
+        self.agents = self.env.agents
+        self.agents_static = self.env.agents_static
+        self.dev_obs_dict = self.env.dev_obs_dict
+
     def reset(self):
         self.agents_done = []
         if self.predefined_env:
@@ -53,27 +58,12 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         else:
             obs = self.env.reset()
 
-
-        # predictions = self.env.predict()
-        # if predictions != {}:
-        #     # pred_pos is a 3 dimensions array (N_Agents, T_pred, 2) containing x and y coordinates of
-        #     # agents at each time step
-        #     pred_pos = np.concatenate([[x[:, 1:3]] for x in list(predictions.values())], axis=0)
-        #     pred_dir = [x[:, 2] for x in list(predictions.values())]
-
+        # RLLib only receives observations from agents that are not done.
         o = dict()
 
         for i_agent in range(len(self.env.agents)):
             data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]),
                                                                          num_features_per_node=8, current_depth=0)
-            # if predictions != {}:
-            #     pred_obs = self.get_prediction_as_observation(pred_pos, pred_dir, i_agent)
-            #
-            #     agent_id_one_hot = np.zeros(len(self.env.agents))
-            #     agent_id_one_hot[i_agent] = 1
-            #     o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
-            # else:
-
             o[i_agent] = [data, distance, agent_data]
 
         # needed for the renderer
@@ -82,6 +72,10 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         self.agents_static = self.env.agents_static
         self.dev_obs_dict = self.env.dev_obs_dict
 
+
+
+        # If step_memory > 1, we concatenate the current observation with the one kept in memory.
+        # This only works for step_memory = 1 or 2 for the moment.
         if self.step_memory < 2:
             return o
         else:
@@ -99,24 +93,11 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         r = dict()
         o = dict()
 
-        # predictions = self.env.predict()
-        # if predictions != {}:
-        #     # pred_pos is a 3 dimensions array (N_Agents, T_pred, 2) containing x and y coordinates of
-        #     # agents at each time step
-        #     pred_pos = np.concatenate([[x[:, 1:3]] for x in list(predictions.values())], axis=0)
-        #     pred_dir = [x[:, 2] for x in list(predictions.values())]
-
         for i_agent in range(len(self.env.agents)):
             if i_agent not in self.agents_done:
                 data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]),
                                                                              num_features_per_node=8, current_depth=0)
 
-                # if predictions != {}:
-                #     pred_obs = self.get_prediction_as_observation(pred_pos, pred_dir, i_agent)
-                #     agent_id_one_hot = np.zeros(len(self.env.agents))
-                #     agent_id_one_hot[i_agent] = 1
-                #     o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
-                # else:
                 o[i_agent] = [data, distance, agent_data]
                 r[i_agent] = rewards[i_agent]
                 d[i_agent] = dones[i_agent]
@@ -146,71 +127,3 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
 
     def get_num_agents(self):
         return self.env.get_num_agents()
-
-    def get_prediction_as_observation(self, pred_pos, pred_dir, agent_handle):
-        '''
-        :param pred_pos: pred_pos should be a 3 dimensions array (N_Agents, T_pred, 2) containing x and y
-                         predicted coordinates of agents at each time step
-        :param pred_dir: pred_dir should be a 2 dimensions array (N_Agents, T_pred) predicted directions
-                         of agents at each time step
-        :param agent_handle: agent index
-        :return: 2 dimensional array (T_pred, N_agents) with value 1 at coord. (t,i) if agent 'agent_handle'
-                and agent i are going to meet at time step t.
-
-        Computes prediction of collision that will be added to the observation.
-        Allows to the agent to know which other train it is about to meet, and when.
-        The id of the other trains are shared, allowing eventually the agents to come
-        up with a priority order of trains.
-        '''
-
-        pred_obs = np.zeros((len(pred_pos[1]), len(self.env.agents)))
-
-        for time_offset in range(len(pred_pos[1])):
-
-            # We consider a time window of t-1:t+1 to find a collision
-            collision_window = list(range(max(time_offset - 1, 0), min(time_offset + 2, len(pred_pos[1]))))
-
-            # coordinate of agent `agent_handle` at time t.
-            coord_agent = pred_pos[agent_handle, time_offset, 0] + 1000 * pred_pos[agent_handle, time_offset, 1]
-
-            # x coordinates of all other agents in the time window
-            # array of dim (N_Agents, 3), the 3 elements corresponding to x coordinates of the agents
-            # at t-1, t, t + 1
-            x_coord_other_agents = pred_pos[list(range(agent_handle)) +
-                                            list(range(agent_handle + 1,
-                                                       len(self.env.agents)))][:, collision_window, 0]
-
-            # y coordinates of all other agents in the time window
-            # array of dim (N_Agents, 3), the 3 elements corresponding to y coordinates of the agents
-            # at t-1, t, t + 1
-            y_coord_other_agents = pred_pos[list(range(agent_handle)) +
-                                            list(range(agent_handle + 1, len(self.env.agents)))][
-                                   :, collision_window, 1]
-
-            coord_other_agents = x_coord_other_agents + 1000 * y_coord_other_agents
-
-            # collision_info here contains the index of the agent colliding with the current agent and
-            # the delta_t at which they visit the same cell (0 for t-1, 1 for t or 2 for t+1)
-            for collision_info in np.argwhere(coord_agent == coord_other_agents):
-                # If they are on the same cell at the same time, there is a collison in all cases
-                if collision_info[1] == 1:
-                    pred_obs[time_offset, collision_info[0] + 1 * (collision_info[0] >= agent_handle)] = 1
-                elif collision_info[1] == 0:
-                    # In this case, the other agent (agent 2) was on the same cell at t-1
-                    # There is a collision if agent 2 is at t, on the cell where was agent 1 at t-1
-                    coord_agent_1_t_minus_1 = pred_pos[agent_handle, time_offset - 1, 0] + \
-                                              1000 * pred_pos[agent_handle, time_offset, 1]
-                    coord_agent_2_t = coord_other_agents[collision_info[0], 1]
-                    if coord_agent_1_t_minus_1 == coord_agent_2_t:
-                        pred_obs[time_offset, collision_info[0] + 1 * (collision_info[0] >= agent_handle)] = 1
-
-                elif collision_info[1] == 2:
-                    # In this case, the other agent (agent 2) will be on the same cell at t+1
-                    # There is a collision if agent 2 is at t, on the cell where will be agent 1 at t+1
-                    coord_agent_1_t_plus_1 = pred_pos[agent_handle, time_offset + 1, 0] + \
-                                             1000 * pred_pos[agent_handle, time_offset, 1]
-                    coord_agent_2_t = coord_other_agents[collision_info[0], 1]
-                    if coord_agent_1_t_plus_1 == coord_agent_2_t:
-                        pred_obs[time_offset, collision_info[0] + 1 * (collision_info[0] >= agent_handle)] = 1
-
-        return pred_obs
diff --git a/RLLib_training/custom_preprocessors.py b/RLLib_training/custom_preprocessors.py
index 2f236f3..6d93aea 100644
--- a/RLLib_training/custom_preprocessors.py
+++ b/RLLib_training/custom_preprocessors.py
@@ -48,11 +48,9 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
     return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
 
 
-class CustomPreprocessor(Preprocessor):
+class TreeObsPreprocessor(Preprocessor):
     def _init_shape(self, obs_space, options):
         return sum([space.shape[0] for space in obs_space]),
-        # return (sum([space.shape[0] for space in obs_space]), )
-        # return ((sum([space.shape[0] for space in obs_space[:2]]) + obs_space[2].shape[0] * obs_space[2].shape[1]),)
 
     def transform(self, observation):
         data = norm_obs_clip(observation[0][0])
@@ -63,33 +61,4 @@ class CustomPreprocessor(Preprocessor):
         agent_data2 = np.clip(observation[1][2], -1, 1)
 
         return np.concatenate((np.concatenate((np.concatenate((data, distance)), agent_data)), np.concatenate((np.concatenate((data2, distance2)), agent_data2))))
-        return norm_obs_clip(observation)
-        return np.concatenate([norm_obs_clip(observation[0]), norm_obs_clip(observation[1])])
-        # if len(observation) == 111:
-        # return np.concatenate([norm_obs_clip(obs) for obs in observation])
-        # print('OBSERVATION:', observation, len(observation[0]))
-        return np.concatenate([norm_obs_clip(observation[0]), observation[1], observation[
-            2].flatten()])  #, norm_obs_clip(observation[1]), observation[2], observation[3].flatten()])
-        #one_hot = observation[-3:]
-        #return np.append(obs, one_hot)
-        # else:
-        #     return observation
-
-
-class ConvModelPreprocessor(Preprocessor):
-    def _init_shape(self, obs_space, options):
-        out_shape = (obs_space[0].shape[0], obs_space[0].shape[1], sum([space.shape[2] for space in obs_space]))
-        return out_shape
-
-    def transform(self, observation):
-        return np.concatenate([observation[0],
-                               observation[1],
-                               observation[2]], axis=2)
-
-
-
-# class NoPreprocessor:
-#     def _init_shape(self, obs_space, options):
-#         num_features = 0
-#         for space in obs_space:
 
diff --git a/RLLib_training/experiment_configs/conv_model_test/config.gin b/RLLib_training/experiment_configs/conv_model_test/config.gin
deleted file mode 100644
index 3c923ca..0000000
--- a/RLLib_training/experiment_configs/conv_model_test/config.gin
+++ /dev/null
@@ -1,23 +0,0 @@
-run_experiment.name = "observation_benchmark_results"
-run_experiment.num_iterations = 1002
-run_experiment.save_every = 50
-run_experiment.hidden_sizes = [32, 32]
-
-run_experiment.map_width = 20
-run_experiment.map_height = 20
-run_experiment.n_agents = 5
-run_experiment.policy_folder_name = "ppo_policy_{config[obs_builder].__class__.__name__}_{config[n_agents]}_agents_conv_model_{config[conv_model]}_"
-
-run_experiment.horizon = 50
-run_experiment.seed = 123
-
-#run_experiment.conv_model = {"grid_search": [True, False]}
-run_experiment.conv_model = False
-
-#run_experiment.obs_builder = {"grid_search": [@GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
-run_experiment.obs_builder = @TreeObsForRailEnv()
-TreeObsForRailEnv.max_depth = 2
-LocalObsForRailEnv.view_radius = 5
-
-run_experiment.entropy_coeff = 0.01
-
diff --git a/RLLib_training/experiment_configs/entropy_coeff_benchmark/config.gin b/RLLib_training/experiment_configs/entropy_coeff_benchmark/config.gin
deleted file mode 100644
index e674447..0000000
--- a/RLLib_training/experiment_configs/entropy_coeff_benchmark/config.gin
+++ /dev/null
@@ -1,19 +0,0 @@
-run_experiment.name = "observation_benchmark_results"
-run_experiment.num_iterations = 1002
-run_experiment.save_every = 100
-run_experiment.hidden_sizes = {"grid_search": [[32, 32], [64, 64], [128, 128], [256, 256]}
-
-run_experiment.map_width = 20
-run_experiment.map_height = 20
-run_experiment.n_agents = 5
-run_experiment.policy_folder_name = "ppo_policy_{config[obs_builder].__class__.__name__}_entropy_coeff_{config[entropy_coeff]}_{config[hidden_sizes][0]}_hidden_sizes_"
-
-run_experiment.horizon = 50
-run_experiment.seed = 123
-
-run_experiment.entropy_coeff = {"grid_search": [1e-3, 1e-2, 0]}
-
-run_experiment.obs_builder = {"grid_search": [@LocalObsForRailEnv()]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
-TreeObsForRailEnv.max_depth = 2
-LocalObsForRailEnv.view_radius = 5
-
diff --git a/RLLib_training/experiment_configs/env_complexity_benchmark/config.gin b/RLLib_training/experiment_configs/env_complexity_benchmark/config.gin
deleted file mode 100644
index 82688fe..0000000
--- a/RLLib_training/experiment_configs/env_complexity_benchmark/config.gin
+++ /dev/null
@@ -1,25 +0,0 @@
-run_experiment.name = "observation_benchmark_results"
-run_experiment.num_iterations = 1002
-run_experiment.save_every = 50
-run_experiment.hidden_sizes = [32, 32]
-
-run_experiment.map_width = 15
-run_experiment.map_height = 15
-run_experiment.n_agents = 8
-run_experiment.rail_generator = "complex_rail_generator"
-run_experiment.nr_extra = 10#{"grid_search": [0, 5, 10, 20, 30, 40, 50, 60]}
-run_experiment.policy_folder_name = "ppo_policy_nr_extra_{config[nr_extra]}_"
-
-run_experiment.horizon = 50
-run_experiment.seed = 123
-
-#run_experiment.conv_model = {"grid_search": [True, False]}
-run_experiment.conv_model = False
-
-#run_experiment.obs_builder = {"grid_search": [@GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
-run_experiment.obs_builder = @TreeObsForRailEnv()
-TreeObsForRailEnv.max_depth = 2
-LocalObsForRailEnv.view_radius = 5
-
-run_experiment.entropy_coeff = 0.01
-
diff --git a/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin b/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin
deleted file mode 100644
index a4ea715..0000000
--- a/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin
+++ /dev/null
@@ -1,30 +0,0 @@
-run_experiment.name = "observation_benchmark_results"
-run_experiment.num_iterations = 2002
-run_experiment.save_every = 100
-run_experiment.hidden_sizes = [32, 32]
-
-run_experiment.map_width = 40
-run_experiment.map_height = 40
-run_experiment.n_agents = {"grid_search": [3, 4, 5, 6, 7, 8]}
-run_experiment.rail_generator = "complex_rail_generator"
-run_experiment.nr_extra = 5
-run_experiment.policy_folder_name = "ppo_policy_two_obs_with_predictions_n_agents_{config[n_agents]}__map_size_{config[map_width]}"
-
-#run_experiment.horizon =
-run_experiment.seed = 123
-
-#run_experiment.conv_model = {"grid_search": [True, False]}
-run_experiment.conv_model = False
-
-#run_experiment.obs_builder = {"grid_search": [@GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
-run_experiment.obs_builder = @TreeObsForRailEnv()
-TreeObsForRailEnv.predictor = @ShortestPathPredictorForRailEnv()
-TreeObsForRailEnv.max_depth = 2
-LocalObsForRailEnv.view_radius = 5
-
-run_experiment.entropy_coeff = 0.001
-run_experiment.kl_coeff = 0.2 #{"grid_search": [0, 0.2]}
-run_experiment.lambda_gae = 0.9 # {"grid_search": [0.9, 1.0]}
-#run_experiment.predictor = "ShortestPathPredictorForRailEnv()"
-run_experiment.step_memory = 2
-run_experiment.min_dist = 10
diff --git a/RLLib_training/experiment_configs/experiment_agent_memory/config.gin b/RLLib_training/experiment_configs/experiment_agent_memory/config.gin
deleted file mode 100644
index 4de0800..0000000
--- a/RLLib_training/experiment_configs/experiment_agent_memory/config.gin
+++ /dev/null
@@ -1,27 +0,0 @@
-run_experiment.name = "memory_experiment_results"
-run_experiment.num_iterations = 2002
-run_experiment.save_every = 50
-run_experiment.hidden_sizes = [32, 32]#{"grid_search": [[32, 32], [64, 64], [128, 128]]}
-
-run_experiment.map_width = 8
-run_experiment.map_height = 8
-run_experiment.n_agents = 3
-run_experiment.rail_generator = "complex_rail_generator"
-run_experiment.nr_extra = 5
-run_experiment.policy_folder_name = "ppo_policy_hidden_size_{config[hidden_sizes][0]}_entropy_coeff_{config[entropy_coeff]}_"
-
-run_experiment.horizon = 50
-run_experiment.seed = 123
-
-#run_experiment.conv_model = {"grid_search": [True, False]}
-run_experiment.conv_model = False
-
-run_experiment.obs_builder = @TreeObsForRailEnv()
-TreeObsForRailEnv.max_depth = 2
-LocalObsForRailEnv.view_radius = 5
-
-run_experiment.entropy_coeff = 1e-4#{"grid_search": [1e-4, 1e-3, 1e-2]}
-run_experiment.kl_coeff = 0.2
-run_experiment.lambda_gae = 0.9
-run_experiment.predictor = None#@DummyPredictorForRailEnv()
-run_experiment.step_memory = 2
diff --git a/RLLib_training/experiment_configs/n_agents_experiment/config.gin b/RLLib_training/experiment_configs/n_agents_experiment/config.gin
deleted file mode 100644
index 025eab9..0000000
--- a/RLLib_training/experiment_configs/n_agents_experiment/config.gin
+++ /dev/null
@@ -1,19 +0,0 @@
-run_experiment.name = "observation_benchmark_results"
-run_experiment.num_iterations = 1002
-run_experiment.save_every = 100
-run_experiment.hidden_sizes = [32,32]
-
-run_experiment.map_width = 20
-run_experiment.map_height = 20
-run_experiment.n_agents = {"grid_search": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}
-run_experiment.policy_folder_name = "ppo_policy_{config[obs_builder].__class__.__name__}_entropy_coeff_{config[entropy_coeff]}_{config[n_agents]}_agents_"
-
-run_experiment.horizon = 50
-run_experiment.seed = 123
-
-run_experiment.entropy_coeff = {"grid_search": [1e-3, 1e-2, 0]}
-
-run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv()]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
-TreeObsForRailEnv.max_depth = 2
-LocalObsForRailEnv.view_radius = 5
-
diff --git a/RLLib_training/experiment_configs/observation_benchmark/config.gin b/RLLib_training/experiment_configs/observation_benchmark/config.gin
deleted file mode 100644
index f5a4dc8..0000000
--- a/RLLib_training/experiment_configs/observation_benchmark/config.gin
+++ /dev/null
@@ -1,17 +0,0 @@
-run_experiment.name = "observation_benchmark_results"
-run_experiment.num_iterations = 1002
-run_experiment.save_every = 100
-run_experiment.hidden_sizes = [32, 32]
-
-run_experiment.map_width = 20
-run_experiment.map_height = 20
-run_experiment.n_agents = 5
-run_experiment.policy_folder_name = "ppo_policy_{config[obs_builder].__class__.__name__}_{config[n_agents]}_agents"
-
-run_experiment.horizon = 50
-run_experiment.seed = 123
-
-run_experiment.obs_builder = {"grid_search": [@LocalObsForRailEnv()]}# [@TreeObsForRailEnv(), @GlobalObsForRailEnv() ]}
-TreeObsForRailEnv.max_depth = 2
-LocalObsForRailEnv.view_radius = 5
-
diff --git a/RLLib_training/experiment_configs/observation_benchmark_loaded_env/config.gin b/RLLib_training/experiment_configs/observation_benchmark_loaded_env/config.gin
deleted file mode 100644
index 1369bb4..0000000
--- a/RLLib_training/experiment_configs/observation_benchmark_loaded_env/config.gin
+++ /dev/null
@@ -1,19 +0,0 @@
-run_experiment.name = "observation_benchmark_loaded_env_results"
-run_experiment.num_iterations = 1002
-run_experiment.save_every = 50
-run_experiment.hidden_sizes = [32, 32]
-
-run_experiment.map_width = 20
-run_experiment.map_height = 10
-run_experiment.n_agents = 8
-run_experiment.policy_folder_name = "ppo_policy_{config[obs_builder].__class__.__name__}"#_entropy_coeff_{config[entropy_coeff]}_{config[hidden_sizes][0]}_hidden_sizes_"
-
-run_experiment.horizon = 50
-run_experiment.seed = 123
-run_experiment.conv_model = False
-
-run_experiment.entropy_coeff = 1e-2
-
-run_experiment.obs_builder = @TreeObsForRailEnv()#{"grid_search": [@LocalObsForRailEnv(), @TreeObsForRailEnv(), @GlobalObsForRailEnv(), @GlobalObsForRailEnvDirectionDependent()]}
-TreeObsForRailEnv.max_depth = 2
-LocalObsForRailEnv.view_radius = 5
diff --git a/RLLib_training/experiment_configs/predictions_test/config.gin b/RLLib_training/experiment_configs/predictions_test/config.gin
deleted file mode 100644
index b5923df..0000000
--- a/RLLib_training/experiment_configs/predictions_test/config.gin
+++ /dev/null
@@ -1,28 +0,0 @@
-run_experiment.name = "memory_experiment_results"
-run_experiment.num_iterations = 2002
-run_experiment.save_every = 50
-run_experiment.hidden_sizes = [32, 32]
-
-run_experiment.map_width = {"grid_search": [8, 10, 12, 14]}
-run_experiment.map_height = 8
-run_experiment.n_agents = 3
-run_experiment.rail_generator = "complex_rail_generator"
-run_experiment.nr_extra = 1
-run_experiment.policy_folder_name = "ppo_policy_with_pred_map_size_{config[map_width]}"
-
-run_experiment.horizon = 50
-run_experiment.seed = 123
-
-#run_experiment.conv_model = {"grid_search": [True, False]}
-run_experiment.conv_model = False
-
-run_experiment.obs_builder = @TreeObsForRailEnv()
-TreeObsForRailEnv.max_depth = 2
-TreeObsForRailEnv.predictor = @DummyPredictorForRailEnv()
-LocalObsForRailEnv.view_radius = 5
-
-run_experiment.entropy_coeff = 1e-3
-run_experiment.kl_coeff = 0.2
-run_experiment.lambda_gae = 0.9
-#run_experiment.predictor = "dummy_predictor"
-run_experiment.step_memory = 1
diff --git a/RLLib_training/render_training_result.py b/RLLib_training/render_training_result.py
index 1719f44..021b9c4 100644
--- a/RLLib_training/render_training_result.py
+++ b/RLLib_training/render_training_result.py
@@ -1,57 +1,39 @@
-from baselines.RLLib_training.RailEnvRLLibWrapper import RailEnvRLLibWrapper
+from RailEnvRLLibWrapper import RailEnvRLLibWrapper
+from custom_preprocessors import TreeObsPreprocessor
 import gym
+import os
 
-
-from flatland.envs.generators import complex_rail_generator
-
-
-# Import PPO trainer: we can replace these imports by any other trainer from RLLib.
 from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
 from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
-# from baselines.CustomPPOTrainer import PPOTrainer as Trainer
 from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
-# from baselines.CustomPPOPolicyGraph import CustomPPOPolicyGraph as PolicyGraph
 
 from ray.rllib.models import ModelCatalog
-from ray.tune.logger import pretty_print
-from baselines.RLLib_training.custom_preprocessors import CustomPreprocessor, ConvModelPreprocessor
-
-from baselines.RLLib_training.custom_models import ConvModelGlobalObs
-
 
 import ray
 import numpy as np
 
-from ray.tune.logger import UnifiedLogger
-import tempfile
-
 import gin
 
-from ray import tune
+from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
+gin.external_configurable(DummyPredictorForRailEnv)
+gin.external_configurable(ShortestPathPredictorForRailEnv)
 
 from ray.rllib.utils.seed import seed as set_seed
-from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv,\
-                                       LocalObsForRailEnv, GlobalObsForRailEnvDirectionDependent
+from flatland.envs.observations import TreeObsForRailEnv
 
 from flatland.utils.rendertools import RenderTool
 import time
 
 gin.external_configurable(TreeObsForRailEnv)
-gin.external_configurable(GlobalObsForRailEnv)
-gin.external_configurable(LocalObsForRailEnv)
-gin.external_configurable(GlobalObsForRailEnvDirectionDependent)
 
-from ray.rllib.models.preprocessors import TupleFlatteningPreprocessor
+ModelCatalog.register_custom_preprocessor("tree_obs_prep", TreeObsPreprocessor)
+ray.init()  # object_store_memory=150000000000, redis_max_memory=30000000000)
 
-ModelCatalog.register_custom_preprocessor("tree_obs_prep", CustomPreprocessor)
-ModelCatalog.register_custom_preprocessor("global_obs_prep", TupleFlatteningPreprocessor)
-ModelCatalog.register_custom_preprocessor("conv_obs_prep", ConvModelPreprocessor)
-ModelCatalog.register_custom_model("conv_model", ConvModelGlobalObs)
-ray.init()#object_store_memory=150000000000, redis_max_memory=30000000000)
+__file_dirname__ = os.path.dirname(os.path.realpath(__file__))
 
-
-CHECKPOINT_PATH = '/home/guillaume/Desktop/distMAgent/experiment_agent_memory/' \
-                  'ppo_policy_hidden_size_32_entropy_coeff_0.0001_mu413rlu/checkpoint_201/checkpoint-201'
+CHECKPOINT_PATH = os.path.join(__file_dirname__, 'experiment_configs', 'config_example', 'ppo_policy_two_obs_with_predictions_n_agents_4_map_size_20q58l5_f7',
+                               'checkpoint_101', 'checkpoint-101')
+# CHECKPOINT_PATH = '/home/guillaume/Desktop/distMAgent/ppo_policy_two_obs_with_predictions_n_agents_7_8e5zko1_/checkpoint_1301/checkpoint-1301'
 
 N_EPISODES = 10
 N_STEPS_PER_EPISODE = 50
@@ -65,54 +47,18 @@ def render_training_result(config):
     # Example configuration to generate a random rail
     env_config = {"width": config['map_width'],
                   "height": config['map_height'],
-                  "rail_generator": "load_env",#config["rail_generator"],
+                  "rail_generator": config["rail_generator"],
                   "nr_extra": config["nr_extra"],
                   "number_of_agents": config['n_agents'],
                   "seed": config['seed'],
                   "obs_builder": config['obs_builder'],
-                  "predictor": config["predictor"],
+                  "min_dist": config['min_dist'],
                   "step_memory": config["step_memory"]}
 
     # Observation space and action space definitions
     if isinstance(config["obs_builder"], TreeObsForRailEnv):
-        if config['predictor'] is None:
-            obs_space = gym.spaces.Tuple(
-                (gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),) * config['step_memory'])
-        else:
-            obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),
-                                          gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
-                                          gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),) * config[
-                                             'step_memory'])
-        preprocessor = "tree_obs_prep"
-
-    elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
-        obs_space = gym.spaces.Tuple((
-            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)),
-            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 8)),
-            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 2))))
-        if config['conv_model']:
-            preprocessor = "conv_obs_prep"
-        else:
-            preprocessor = "global_obs_prep"
-
-    elif isinstance(config["obs_builder"], GlobalObsForRailEnvDirectionDependent):
-        obs_space = gym.spaces.Tuple((
-            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)),
-            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 5)),
-            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 2))))
-        if config['conv_model']:
-            preprocessor = "conv_obs_prep"
-        else:
-            preprocessor = "global_obs_prep"
-
-    elif isinstance(config["obs_builder"], LocalObsForRailEnv):
-        view_radius = config["obs_builder"].view_radius
-        obs_space = gym.spaces.Tuple((
-            gym.spaces.Box(low=0, high=1, shape=(2 * view_radius + 1, 2 * view_radius + 1, 16)),
-            gym.spaces.Box(low=0, high=1, shape=(2 * view_radius + 1, 2 * view_radius + 1, 2)),
-            gym.spaces.Box(low=0, high=1, shape=(2 * view_radius + 1, 2 * view_radius + 1, 4)),
-            gym.spaces.Box(low=0, high=1, shape=(4,))))
-        preprocessor = "global_obs_prep"
+        obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(168,)),) * 2)
+        preprocessor = TreeObsPreprocessor
 
     else:
         raise ValueError("Undefined observation space")
@@ -122,23 +68,20 @@ def render_training_result(config):
     # Dict with the different policies to train
     policy_graphs = {
         config['policy_folder_name'].format(**locals()): (PolicyGraph, obs_space, act_space, {})
+        # "ppo_policy": (PolicyGraph, obs_space, act_space, {})
     }
 
     def policy_mapping_fn(agent_id):
-        return config['policy_folder_name'].format(**locals())
+        return "ppo_policy"
 
     # Trainer configuration
     trainer_config = DEFAULT_CONFIG.copy()
-    if config['conv_model']:
-        trainer_config['model'] = {"custom_model": "conv_model", "custom_preprocessor": preprocessor}
-    else:
-        trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": preprocessor}
+
+    trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes']}
 
     trainer_config['multiagent'] = {"policy_graphs": policy_graphs,
                                     "policy_mapping_fn": policy_mapping_fn,
                                     "policies_to_train": list(policy_graphs.keys())}
-    trainer_config["horizon"] = config['horizon']
-
 
     trainer_config["num_workers"] = 0
     trainer_config["num_cpus_per_worker"] = 4
@@ -161,15 +104,15 @@ def render_training_result(config):
 
     trainer = Trainer(env=RailEnvRLLibWrapper, config=trainer_config)
 
-    print('hidden sizes:', config['hidden_sizes'])
     trainer.restore(CHECKPOINT_PATH)
 
+    # policy = trainer.get_policy("ppo_policy")
     policy = trainer.get_policy(config['policy_folder_name'].format(**locals()))
 
-
-    preprocessor = CustomPreprocessor(obs_space)
+    preprocessor = preprocessor(obs_space)
     env_renderer = RenderTool(env, gl="PIL")
     for episode in range(N_EPISODES):
+
         observation = env.reset()
         for i in range(N_STEPS_PER_EPISODE):
             preprocessed_obs = []
@@ -178,13 +121,18 @@ def render_training_result(config):
             action, _, infos = policy.compute_actions(preprocessed_obs, [])
             logits = infos['behaviour_logits']
             actions = dict()
+
+            # We select the greedy action.
             for j, logit in enumerate(logits):
                 actions[j] = np.argmax(logit)
+
+            # Uncomment the lines below to instead sample actions stochastically from the policy.
             # for j, act in enumerate(action):
                 # actions[j] = act
+
+            # Pause so the rendering of each step stays visible for a moment.
             time.sleep(1)
-            print(actions, logits)
-            # print(action, print(infos['behaviour_logits']))
+
             env_renderer.renderEnv(show=True, frames=True, iEpisode=episode, iStep=i,
                                    action_dict=list(actions.values()))
 
@@ -195,9 +143,9 @@ def render_training_result(config):
 
 @gin.configurable
 def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
-                   map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder,
-                   entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff,
-                   lambda_gae, predictor, step_memory):
+                   map_width, map_height, policy_folder_name, obs_builder,
+                   entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae,
+                   step_memory, min_dist):
 
     render_training_result(
         config={"n_agents": n_agents,
@@ -205,8 +153,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
                 "save_every": save_every,
                 "map_width": map_width,
                 "map_height": map_height,
-                "local_dir": local_dir,
-                "horizon": horizon,  # Max number of time steps
                 'policy_folder_name': policy_folder_name,
                 "obs_builder": obs_builder,
                 "entropy_coeff": entropy_coeff,
@@ -216,14 +162,12 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
                 "nr_extra": nr_extra,
                 "kl_coeff": kl_coeff,
                 "lambda_gae": lambda_gae,
-                "predictor": predictor,
+                "min_dist": min_dist,
                 "step_memory": step_memory
                 }
     )
 
 
 if __name__ == '__main__':
-    gin.external_configurable(tune.grid_search)
-    dir = '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/RLLib_training/experiment_configs/experiment_agent_memory'  # To Modify
-    gin.parse_config_file(dir + '/config.gin')
-    run_experiment(local_dir=dir)
+    gin.parse_config_file(os.path.join(__file_dirname__, 'experiment_configs', 'config_example', 'config.gin'))
+    run_experiment()
diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py
index 4b14dcf..1bac614 100644
--- a/RLLib_training/train_experiment.py
+++ b/RLLib_training/train_experiment.py
@@ -3,7 +3,7 @@ import os
 import gin
 import gym
 from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
-from importlib_resources import path
+
 # Import PPO trainer: we can replace these imports by any other trainer from RLLib.
 from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
 from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
@@ -17,30 +17,22 @@ import ray
 
 from ray.tune.logger import UnifiedLogger
 from ray.tune.logger import pretty_print
+import os
 
 from RailEnvRLLibWrapper import RailEnvRLLibWrapper
-from custom_models import ConvModelGlobalObs
-from custom_preprocessors import CustomPreprocessor, ConvModelPreprocessor
 import tempfile
 
 from ray import tune
 
 from ray.rllib.utils.seed import seed as set_seed
-from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv, \
-    LocalObsForRailEnv, GlobalObsForRailEnvDirectionDependent
+from flatland.envs.observations import TreeObsForRailEnv
 
 gin.external_configurable(TreeObsForRailEnv)
-gin.external_configurable(GlobalObsForRailEnv)
-gin.external_configurable(LocalObsForRailEnv)
-gin.external_configurable(GlobalObsForRailEnvDirectionDependent)
 
-from ray.rllib.models.preprocessors import TupleFlatteningPreprocessor
 import numpy as np
+from custom_preprocessors import TreeObsPreprocessor
 
-ModelCatalog.register_custom_preprocessor("tree_obs_prep", CustomPreprocessor)
-ModelCatalog.register_custom_preprocessor("global_obs_prep", TupleFlatteningPreprocessor)
-ModelCatalog.register_custom_preprocessor("conv_obs_prep", ConvModelPreprocessor)
-ModelCatalog.register_custom_model("conv_model", ConvModelGlobalObs)
+ModelCatalog.register_custom_preprocessor("tree_obs_prep", TreeObsPreprocessor)
 ray.init()  # object_store_memory=150000000000, redis_max_memory=30000000000)
 
 __file_dirname__ = os.path.dirname(os.path.realpath(__file__))
@@ -86,72 +78,44 @@ def train(config, reporter):
     if isinstance(config["obs_builder"], TreeObsForRailEnv):
         obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(168,)),) * 2)
         preprocessor = "tree_obs_prep"
-
-    elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
-        obs_space = gym.spaces.Tuple((
-            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)),
-            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 8)),
-            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 2))))
-        if config['conv_model']:
-            preprocessor = "conv_obs_prep"
-        else:
-            preprocessor = "global_obs_prep"
-
-    elif isinstance(config["obs_builder"], GlobalObsForRailEnvDirectionDependent):
-        obs_space = gym.spaces.Tuple((
-            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)),
-            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 5)),
-            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 2))))
-        if config['conv_model']:
-            preprocessor = "conv_obs_prep"
-        else:
-            preprocessor = "global_obs_prep"
-
-    elif isinstance(config["obs_builder"], LocalObsForRailEnv):
-        view_radius = config["obs_builder"].view_radius
-        obs_space = gym.spaces.Tuple((
-            gym.spaces.Box(low=0, high=1, shape=(2 * view_radius + 1, 2 * view_radius + 1, 16)),
-            gym.spaces.Box(low=0, high=1, shape=(2 * view_radius + 1, 2 * view_radius + 1, 2)),
-            gym.spaces.Box(low=0, high=1, shape=(2 * view_radius + 1, 2 * view_radius + 1, 4)),
-            gym.spaces.Box(low=0, high=1, shape=(4,))))
-        preprocessor = "global_obs_prep"
-
     else:
-        raise ValueError("Undefined observation space")
+        raise ValueError("Undefined observation space")  # Only TreeObsForRailEnv is supported for now.
 
     act_space = gym.spaces.Discrete(5)
 
-    # Dict with the different policies to train
+    # Dict with the different policies to train. In this case, all trains share the same policy.
     policy_graphs = {
-        config['policy_folder_name'].format(**locals()): (PolicyGraph, obs_space, act_space, {})
+        "ppo_policy": (PolicyGraph, obs_space, act_space, {})
     }
 
+    # Function that maps an agent id to the name of its respective policy.
     def policy_mapping_fn(agent_id):
-        return config['policy_folder_name'].format(**locals())
+        return "ppo_policy"
 
     # Trainer configuration
     trainer_config = DEFAULT_CONFIG.copy()
-    if config['conv_model']:
-        trainer_config['model'] = {"custom_model": "conv_model", "custom_preprocessor": preprocessor}
-    else:
-        trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": preprocessor}
+    trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": preprocessor}
 
     trainer_config['multiagent'] = {"policy_graphs": policy_graphs,
                                     "policy_mapping_fn": policy_mapping_fn,
                                     "policies_to_train": list(policy_graphs.keys())}
-    trainer_config["horizon"] = 3 * (config['map_width'] + config['map_height'])#config['horizon']
 
+    # The maximum number of time steps per episode is set to 3 * (map_width + map_height).
+    trainer_config["horizon"] = 3 * (config['map_width'] + config['map_height'])
+
+    # Parallelization parameters
     trainer_config["num_workers"] = 0
-    trainer_config["num_cpus_per_worker"] = 7
+    trainer_config["num_cpus_per_worker"] = 3
     trainer_config["num_gpus"] = 0.0
     trainer_config["num_gpus_per_worker"] = 0.0
     trainer_config["num_cpus_for_driver"] = 1
     trainer_config["num_envs_per_worker"] = 1
+
+    # Parameters for PPO training
     trainer_config['entropy_coeff'] = config['entropy_coeff']
     trainer_config["env_config"] = env_config
     trainer_config["batch_mode"] = "complete_episodes"
     trainer_config['simple_optimizer'] = False
-    trainer_config['postprocess_inputs'] = True
     trainer_config['log_level'] = 'WARN'
     trainer_config['num_sgd_iter'] = 10
     trainer_config['clip_param'] = 0.2
@@ -163,9 +127,7 @@ def train(config, reporter):
         }
 
     def logger_creator(conf):
-        """Creates a Unified logger with a default logdir prefix
-        containing the agent name and the env id
-        """
+        """Creates a Unified logger with a default logdir prefix."""
         logdir = config['policy_folder_name'].format(**locals())
         logdir = tempfile.mkdtemp(
             prefix=logdir, dir=config['local_dir'])
@@ -212,11 +174,10 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
                 "kl_coeff": kl_coeff,
                 "lambda_gae": lambda_gae,
                 "min_dist": min_dist,
-                # "predictor": predictor,
                 "step_memory": step_memory
                 },
         resources_per_trial={
-            "cpu": 8,
+            "cpu": 3,
             "gpu": 0
         },
         verbose=2,
@@ -225,10 +186,7 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
 
 
 if __name__ == '__main__':
-    gin.external_configurable(tune.grid_search)
-    # with path('RLLib_training.experiment_configs.n_agents_experiment', 'config.gin') as f:
-    #     gin.parse_config_file(f)
-    gin.parse_config_file('/mount/SDC/flatland/baselines/RLLib_training/experiment_configs/env_size_benchmark_3_agents/config.gin')
-    dir = '/mount/SDC/flatland/baselines/RLLib_training/experiment_configs/env_size_benchmark_3_agents'
-    # dir = os.path.join(__file_dirname__, 'experiment_configs', 'experiment_agent_memory')
+    print(os.path.join(__file_dirname__, 'experiment_configs', 'config_example', 'config.gin'))
+    gin.parse_config_file(os.path.join(__file_dirname__, 'experiment_configs', 'config_example', 'config.gin'))
+    dir = os.path.join(__file_dirname__, 'experiment_configs', 'config_example')
     run_experiment(local_dir=dir)
-- 
GitLab