From a54b734a86dc40fad0c66371f0e723200fe08d04 Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume.mollard2@gmail.com>
Date: Mon, 20 May 2019 10:02:50 +0200
Subject: [PATCH] observation benchmark script

---
 CustomPreprocessor.py                         |  8 +++++
 .../observation_benchmark/config.gin          | 16 ++++++++++
 train_experiment.py                           | 30 +++++++++++++++----
 3 files changed, 48 insertions(+), 6 deletions(-)
 create mode 100644 experiment_configs/observation_benchmark/config.gin

diff --git a/CustomPreprocessor.py b/CustomPreprocessor.py
index ce8cd9e..3a2e5c1 100644
--- a/CustomPreprocessor.py
+++ b/CustomPreprocessor.py
@@ -54,3 +54,11 @@ class CustomPreprocessor(Preprocessor):
 
     def transform(self, observation):
         return norm_obs_clip(observation)  # return the preprocessed observation
+
+
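+# TODO: the NoPreprocessor sketch below is unfinished (its loop has no body) and is left commented out; finish or remove it.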
+# class NoPreprocessor:
+#     def _init_shape(self, obs_space, options):
+#         num_features = 0
+#         for space in obs_space:
+
diff --git a/experiment_configs/observation_benchmark/config.gin b/experiment_configs/observation_benchmark/config.gin
new file mode 100644
index 0000000..2d9e726
--- /dev/null
+++ b/experiment_configs/observation_benchmark/config.gin
@@ -0,0 +1,16 @@
+run_experiment.name = "n_agents_results"  # NOTE(review): wording matches the n_agents experiment, but this file lives under observation_benchmark — confirm the name is intended
+run_experiment.num_iterations = 1002
+run_experiment.save_every = 200
+run_experiment.hidden_sizes = [32, 32]
+
+run_experiment.map_width = 20
+run_experiment.map_height = 20
+run_experiment.n_agents = {"grid_search": [2, 5]}  # candidate agent counts listed under a "grid_search" key: 2 and 5
+run_experiment.policy_folder_name = "ppo_policy_{config[n_agents]}_agents"  # NOTE(review): placeholder references n_agents only, not obs_builder — confirm runs with different observation builders do not share a folder
+
+run_experiment.horizon = 50
+run_experiment.seed = 123
+
+run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv, @GlobalObsForRailEnv]}  # candidate observation builders: tree and global
+TreeObsForRailEnv.max_depth = 2  # NOTE(review): train_experiment.py hard-codes a (105,) shape for tree observations — presumably tied to this depth; confirm before changing
+
diff --git a/train_experiment.py b/train_experiment.py
index a70258b..b53bac6 100644
--- a/train_experiment.py
+++ b/train_experiment.py
@@ -24,9 +24,12 @@ import tempfile
 import gin
 
 from ray import tune
+
 from ray.rllib.utils.seed import seed as set_seed
+from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
+from ray.rllib.models.preprocessors import TupleFlatteningPreprocessor
 
-ModelCatalog.register_custom_preprocessor("my_prep", CustomPreprocessor)
+ModelCatalog.register_custom_preprocessor("tree_obs_prep", CustomPreprocessor)
 ray.init()
 
 
@@ -55,7 +58,22 @@ def train(config, reporter):
                   "seed": config['seed']}
 
     # Observation space and action space definitions
-    obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
+    if type(config["obs_builder"]) == TreeObsForRailEnv:
+        obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
+        preprocessor = "tree_obs_prep"
+
+    elif type(config["obs_builder"]) == GlobalObsForRailEnv:
+        obs_space = gym.spaces.Tuple((
+            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)),
+            gym.spaces.Box(low=0, high=1, shape=(4, config['map_height'], config['map_width'])),
+            gym.spaces.Space(4)))
+
+        preprocessor = TupleFlatteningPreprocessor
+
+    else:
+        raise ValueError("Undefined observation space")
+
+
     act_space = gym.spaces.Discrete(4)
 
     # Dict with the different policies to train
@@ -69,7 +87,7 @@ def train(config, reporter):
 
     # Trainer configuration
     trainer_config = DEFAULT_CONFIG.copy()
-    trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": "my_prep"}
+    trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": preprocessor}
     trainer_config['multiagent'] = {"policy_graphs": policy_graphs,
                                   "policy_mapping_fn": policy_mapping_fn,
                                   "policies_to_train": list(policy_graphs.keys())}
@@ -129,8 +147,8 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
                 "seed": seed
                 },
         resources_per_trial={
-            "cpu": 1,
-            "gpu": 0.0
+            "cpu": 12,
+            "gpu": 0.5
         },
         local_dir=local_dir
     )
@@ -138,6 +156,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
 
 if __name__ == '__main__':
     gin.external_configurable(tune.grid_search)
-    dir = '/home/guillaume/Desktop/distMAgent/baselines/experiment_configs/n_agents_experiment'  # To Modify
+    dir = '/mount/SDC/flatland/baselines/experiment_configs/observation_benchmark'  # To modify: machine-specific directory that holds config.gin; also passed to run_experiment as local_dir
     gin.parse_config_file(dir + '/config.gin')
     run_experiment(local_dir=dir)
-- 
GitLab