From 4306c5993e5b116c055132342577cb799b292276 Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume@iccluster091.iccluster.epfl.ch>
Date: Wed, 15 May 2019 14:36:58 +0200
Subject: [PATCH] grid search trainer works correctly

---
 CustomPreprocessor.py                         | 56 +++++++++++++++++
 RailEnvRLLibWrapper.py                        | 15 +++--
 .../n_agents_grid_search/config.gin           |  7 ++-
 grid_search_train.py                          | 60 +++++++++++--------
 train.py                                      | 17 ++----
 5 files changed, 109 insertions(+), 46 deletions(-)
 create mode 100644 CustomPreprocessor.py

diff --git a/CustomPreprocessor.py b/CustomPreprocessor.py
new file mode 100644
index 0000000..ce8cd9e
--- /dev/null
+++ b/CustomPreprocessor.py
@@ -0,0 +1,56 @@
+import numpy as np
+from ray.rllib.models.preprocessors import Preprocessor
+
+
+def max_lt(seq, val):
+    """
+    Return the greatest non-negative item in seq that is smaller than val.
+    0 is returned if seq is empty or no item qualifies.
+    """
+    max = 0
+    idx = len(seq) - 1
+    while idx >= 0:
+        if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
+            max = seq[idx]
+        idx -= 1
+    return max
+
+
+def min_lt(seq, val):
+    """
+    Return the smallest item in seq that is greater than val.
+    np.inf is returned if seq is empty or no item qualifies.
+    """
+    min = np.inf
+    idx = len(seq) - 1
+    while idx >= 0:
+        if seq[idx] > val and seq[idx] < min:
+            min = seq[idx]
+        idx -= 1
+    return min
+
+
+def norm_obs_clip(obs, clip_min=-1, clip_max=1):
+    """
+    Normalize an observation by the spread between its min and max values, then clip the result.
+    :param obs: observation that should be normalized
+    :param clip_min: minimum value to which the observation is clipped
+    :param clip_max: maximum value to which the observation is clipped
+    :return: the normalized and clipped observation
+    """
+    max_obs = max(1, max_lt(obs, 1000))
+    min_obs = max(0, min_lt(obs, 0))
+    if max_obs == min_obs:
+        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
+    norm = np.abs(max_obs - min_obs)
+    if norm == 0:
+        norm = 1.
+    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
+
+
+class CustomPreprocessor(Preprocessor):
+    def _init_shape(self, obs_space, options):
+        return (105,)
+
+    def transform(self, observation):
+        return norm_obs_clip(observation)  # return the preprocessed observation
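
For reference, the preprocessor above is meant to be registered with RLlib's ModelCatalog and then referenced by name from the model config, which is exactly what the trainer scripts below do. A minimal sketch, assuming the older RLlib API already used in this patch and that CustomPreprocessor.py is importable from the working directory:

    import numpy as np
    from ray.rllib.models import ModelCatalog
    from CustomPreprocessor import CustomPreprocessor, norm_obs_clip

    # Register under a name; any model config that sets "custom_preprocessor"
    # to that name gets norm_obs_clip applied to every raw observation.
    ModelCatalog.register_custom_preprocessor("my_prep", CustomPreprocessor)
    model_config = {"fcnet_hiddens": [32, 32], "custom_preprocessor": "my_prep"}

    # norm_obs_clip on its own: values are rescaled by the observation's own
    # min/max spread and clipped into [-1, 1].
    print(norm_obs_clip(np.array([1., 5., 20., 1000.])))
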
diff --git a/RailEnvRLLibWrapper.py b/RailEnvRLLibWrapper.py
index fef562d..f6697cc 100644
--- a/RailEnvRLLibWrapper.py
+++ b/RailEnvRLLibWrapper.py
@@ -4,7 +4,7 @@ from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.envs.generators import random_rail_generator
 
 
-class RailEnvRLLibWrapper(RailEnv, MultiAgentEnv):
+class RailEnvRLLibWrapper(MultiAgentEnv):
 
     def __init__(self, config):
                  # width,
@@ -12,16 +12,19 @@ class RailEnvRLLibWrapper(RailEnv, MultiAgentEnv):
                  # rail_generator=random_rail_generator(),
                  # number_of_agents=1,
                  # obs_builder_object=TreeObsForRailEnv(max_depth=2)):
+        super(RailEnvRLLibWrapper, self).__init__()
 
-        super(RailEnvRLLibWrapper, self).__init__(width=config["width"], height=config["height"], rail_generator=config["rail_generator"],
+        self.env = RailEnv(width=config["width"], height=config["height"], rail_generator=config["rail_generator"],
                 number_of_agents=config["number_of_agents"])
 
-    def reset(self, regen_rail=True, replace_agents=True):
+    def reset(self):
         self.agents_done = []
-        return super(RailEnvRLLibWrapper, self).reset(regen_rail, replace_agents)
+        return self.env.reset()
 
     def step(self, action_dict):
-        obs, rewards, dones, infos = super(RailEnvRLLibWrapper, self).step(action_dict)
+        obs, rewards, dones, infos = self.env.step(action_dict)
+        # print(obs)
+
         d = dict()
         r = dict()
         o = dict()
@@ -44,4 +47,4 @@ class RailEnvRLLibWrapper(RailEnv, MultiAgentEnv):
         return o, r, d, infos
     
     def get_agent_handles(self):
-        return super(RailEnvRLLibWrapper, self).get_agent_handles()
+        return self.env.get_agent_handles()
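
Since the wrapper now wraps a RailEnv instead of inheriting from it, it is constructed from a single config dict. A minimal usage sketch, reusing the same env_config keys the trainer below passes (the generator arguments are simply the ones this patch uses, not a recommendation):

    from flatland.envs.generators import complex_rail_generator
    from RailEnvRLLibWrapper import RailEnvRLLibWrapper

    env_config = {"width": 20,
                  "height": 20,
                  "rail_generator": complex_rail_generator(nr_start_goal=5, min_dist=5,
                                                           max_dist=99999, seed=0),
                  "number_of_agents": 5}

    env = RailEnvRLLibWrapper(env_config)  # builds the inner RailEnv from the dict
    obs = env.reset()                      # dict of per-agent observations
    # step() consumes and returns per-agent dicts, as RLlib's MultiAgentEnv expects
    obs, rewards, dones, infos = env.step({h: 0 for h in env.get_agent_handles()})
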
diff --git a/grid_search_configs/n_agents_grid_search/config.gin b/grid_search_configs/n_agents_grid_search/config.gin
index 9830838..ab3d76e 100644
--- a/grid_search_configs/n_agents_grid_search/config.gin
+++ b/grid_search_configs/n_agents_grid_search/config.gin
@@ -3,9 +3,10 @@ run_grid_search.num_iterations = 1002
 run_grid_search.save_every = 200
 run_grid_search.hidden_sizes = [32, 32]
 
-run_grid_search.map_width = 50
-run_grid_search.map_height = 50
-run_grid_search.n_agents = {"grid_search": [2, 5, 10, 20]}
+run_grid_search.map_width = 20
+run_grid_search.map_height = 20
+run_grid_search.n_agents = {"grid_search": [1, 2, 5, 10]}
+run_grid_search.policy_folder_name = "ppo_policy_{config[n_agents]}_agents"
 
 run_grid_search.horizon = 50
 
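
The {"grid_search": [...]} value in config.gin is Ray Tune's grid-search spec: once it reaches the config dict handed to tune.run, Tune starts one trial per listed value. The dict literal is just the expanded form of tune.grid_search, roughly:

    from ray import tune

    # tune.grid_search wraps its values in the same dict the gin file spells out
    assert tune.grid_search([1, 2, 5, 10]) == {"grid_search": [1, 2, 5, 10]}

    # run_grid_search therefore hands tune.run a config along these lines,
    # which expands into four trials (n_agents = 1, 2, 5, 10)
    config = {"n_agents": tune.grid_search([1, 2, 5, 10])}
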
diff --git a/grid_search_train.py b/grid_search_train.py
index 0f04d1c..1d06bea 100644
--- a/grid_search_train.py
+++ b/grid_search_train.py
@@ -10,28 +10,22 @@ from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
 
 from ray.rllib.models import ModelCatalog
 from ray.tune.logger import pretty_print
-from ray.rllib.models.preprocessors import Preprocessor
+from baselines.CustomPreprocessor import CustomPreprocessor
 
 
 import ray
 import numpy as np
 
+from ray.tune.logger import UnifiedLogger
+import tempfile
+
 import gin
 
 from ray import tune
 
 
-class MyPreprocessorClass(Preprocessor):
-    def _init_shape(self, obs_space, options):
-        return (105,)
-
-    def transform(self, observation):
-        return observation  # return the preprocessed observation
-
-
-ModelCatalog.register_custom_preprocessor("my_prep", MyPreprocessorClass)
-ray.init()
-
+ModelCatalog.register_custom_preprocessor("my_prep", CustomPreprocessor)
+ray.init(object_store_memory=150000000000)
 
 def train(config, reporter):
     print('Init Env')
@@ -57,7 +51,7 @@ def train(config, reporter):
     """
     env_config = {"width":config['map_width'],
                   "height":config['map_height'],
-                  "rail_generator":complex_rail_generator(nr_start_goal=50, min_dist=5, max_dist=99999, seed=0),
+                  "rail_generator":complex_rail_generator(nr_start_goal=config['n_agents'], min_dist=5, max_dist=99999, seed=0),
                   "number_of_agents":config['n_agents']}
     """
     env = RailEnv(width=20,
@@ -80,11 +74,11 @@ def train(config, reporter):
 
     # Dict with the different policies to train
     policy_graphs = {
-        f"ppo_policy": (PPOPolicyGraph, obs_space, act_space, {})
+        config['policy_folder_name'].format(**locals()): (PPOPolicyGraph, obs_space, act_space, {})
     }
 
     def policy_mapping_fn(agent_id):
-        return f"ppo_policy"
+        return config['policy_folder_name'].format(**locals())
 
     agent_config = ppo.DEFAULT_CONFIG.copy()
     agent_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": "my_prep"}
@@ -93,16 +87,29 @@ def train(config, reporter):
                                   "policies_to_train": list(policy_graphs.keys())}
     agent_config["horizon"] = config['horizon']
 
-    # agent_config["num_workers"] = 0
-    # agent_config["num_cpus_per_worker"] = 10
-    # agent_config["num_gpus"] = 0.5
-    # agent_config["num_gpus_per_worker"] = 0.5
-    # agent_config["num_cpus_for_driver"] = 1
-    # agent_config["num_envs_per_worker"] = 10
+    agent_config["num_workers"] = 0
+    agent_config["num_cpus_per_worker"] = 10
+    agent_config["num_gpus"] = 0.5
+    agent_config["num_gpus_per_worker"] = 0.5
+    agent_config["num_cpus_for_driver"] = 2
+    agent_config["num_envs_per_worker"] = 10
     agent_config["env_config"] = env_config
     agent_config["batch_mode"] = "complete_episodes"
+    agent_config['simple_optimizer'] = False
+
+    def logger_creator(conf):
+        """Creates a Unified logger with a default logdir prefix
+        containing the agent name and the env id
+        """
+        print("FOLDER", config['policy_folder_name'])
+        logdir = config['policy_folder_name'].format(**locals())
+        logdir = tempfile.mkdtemp(
+            prefix=logdir, dir=config['local_dir'])
+        return UnifiedLogger(conf, logdir, None)
+
+    logger = logger_creator
 
-    ppo_trainer = PPOTrainer(env=RailEnvRLLibWrapper, config=agent_config)
+    ppo_trainer = PPOTrainer(env=RailEnvRLLibWrapper, config=agent_config, logger_creator=logger)
 
     for i in range(100000 + 2):
         print("== Iteration", i, "==")
@@ -119,7 +126,7 @@ def train(config, reporter):
 
 @gin.configurable
 def run_grid_search(name, num_iterations, n_agents, hidden_sizes, save_every,
-                    map_width, map_height, horizon, local_dir):
+                    map_width, map_height, horizon, policy_folder_name, local_dir):
 
     tune.run(
         train,
@@ -131,10 +138,11 @@ def run_grid_search(name, num_iterations, n_agents, hidden_sizes, save_every,
                 "map_width": map_width,
                 "map_height": map_height,
                 "local_dir": local_dir,
-                "horizon": horizon  # Max number of time steps
+                "horizon": horizon,  # Max number of time steps
+                'policy_folder_name': policy_folder_name
                 },
         resources_per_trial={
-            "cpu": 11,
+            "cpu": 12,
             "gpu": 0.5
         },
         local_dir=local_dir
@@ -143,6 +151,6 @@ def run_grid_search(name, num_iterations, n_agents, hidden_sizes, save_every,
 
 if __name__ == '__main__':
     gin.external_configurable(tune.grid_search)
-    dir = 'baselines/grid_search_configs/n_agents_grid_search'
+    dir = '/mount/SDC/flatland/baselines/grid_search_configs/n_agents_grid_search'
     gin.parse_config_file(dir + '/config.gin')
     run_grid_search(local_dir=dir)
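
The policy_folder_name template is expanded with str.format(**locals()): the enclosing scope provides config, and "{config[n_agents]}" indexes that dict, so every grid-search trial gets its own policy name and log-directory prefix. A small illustration with example values:

    # hypothetical trial config, mirroring what Tune passes into train()
    config = {"n_agents": 5,
              "policy_folder_name": "ppo_policy_{config[n_agents]}_agents"}

    # inside train(), locals() contains config, so the template can index it
    policy_name = config["policy_folder_name"].format(**locals())
    assert policy_name == "ppo_policy_5_agents"
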
diff --git a/train.py b/train.py
index 2f49720..ecea536 100644
--- a/train.py
+++ b/train.py
@@ -19,7 +19,7 @@ from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
 from ray.tune.registry import register_env
 from ray.rllib.models import ModelCatalog
 from ray.tune.logger import pretty_print
-from ray.rllib.models.preprocessors import Preprocessor
+from baselines.CustomPreprocessor import CustomPreprocessor
 
 
 import ray
@@ -30,14 +30,8 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
 # RailEnv.__bases__ = (RailEnv.__bases__[0], MultiAgentEnv)
 
 
-class MyPreprocessorClass(Preprocessor):
-    def _init_shape(self, obs_space, options):
-        return (105,)
 
-    def transform(self, observation):
-        return observation  # return the preprocessed observation
-
-ModelCatalog.register_custom_preprocessor("my_prep", MyPreprocessorClass)
+ModelCatalog.register_custom_preprocessor("my_prep", CustomPreprocessor)
 ray.init()
 
 def train(config):
@@ -93,19 +87,20 @@ def train(config):
         return f"ppo_policy"
 
     agent_config = ppo.DEFAULT_CONFIG.copy()
-    agent_config['model'] = {"fcnet_hiddens": [32, 32]}#, "custom_preprocessor": "my_prep"}
+    agent_config['model'] = {"fcnet_hiddens": [32, 32], "custom_preprocessor": "my_prep"}
     agent_config['multiagent'] = {"policy_graphs": policy_graphs,
                                   "policy_mapping_fn": policy_mapping_fn,
                                   "policies_to_train": list(policy_graphs.keys())}
     agent_config["horizon"] = 50
-    #agent_config["num_workers"] = 0
+    agent_config["num_workers"] = 0
+    # agent_config["sample_batch_size"]: 1000
     #agent_config["num_cpus_per_worker"] = 40
     #agent_config["num_gpus"] = 2.0
     #agent_config["num_gpus_per_worker"] = 2.0
     #agent_config["num_cpus_for_driver"] = 5
     #agent_config["num_envs_per_worker"] = 15
     agent_config["env_config"] = env_config
-    agent_config["batch_mode"] = "complete_episodes"
+    #agent_config["batch_mode"] = "complete_episodes"
 
     ppo_trainer = PPOTrainer(env=RailEnvRLLibWrapper, config=agent_config)
 
-- 
GitLab