From d68164dba751c0696b48c0b6d9533a7918138118 Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume@iccluster091.iccluster.epfl.ch>
Date: Wed, 15 May 2019 17:14:27 +0200
Subject: [PATCH] some changes for more convenient trainer modification

---
 grid_search_train.py | 77 +++++++++++++++++---------------------------
 1 file changed, 30 insertions(+), 47 deletions(-)

diff --git a/grid_search_train.py b/grid_search_train.py
index 1d06bea..b30cb40 100644
--- a/grid_search_train.py
+++ b/grid_search_train.py
@@ -4,9 +4,11 @@ import gym
 
 from flatland.envs.generators import complex_rail_generator
 
-import ray.rllib.agents.ppo.ppo as ppo
-from ray.rllib.agents.ppo.ppo import PPOTrainer
-from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
+
+# Import PPO trainer: we can replace these imports by any other trainer from RLLib.
+from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
+from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
+from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
 
 from ray.rllib.models import ModelCatalog
 from ray.tune.logger import pretty_print
@@ -27,6 +29,7 @@ from ray import tune
 ModelCatalog.register_custom_preprocessor("my_prep", CustomPreprocessor)
 ray.init(object_store_memory=150000000000)
 
+
 def train(config, reporter):
     print('Init Env')
 
@@ -42,66 +45,47 @@ def train(config, reporter):
                               1,  # Case 1c (9)  - simple turn left
                               1]  # Case 2b (10) - simple switch mirrored
 
-    # Example generate a random rail
-    """
-    env = RailEnv(width=10,
-                  height=10,
-                  rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
-                  number_of_agents=1)
-    """
+    # Example configuration to generate a random rail
     env_config = {"width":config['map_width'],
                   "height":config['map_height'],
                   "rail_generator":complex_rail_generator(nr_start_goal=config['n_agents'], min_dist=5, max_dist=99999, seed=0),
                   "number_of_agents":config['n_agents']}
-    """
-    env = RailEnv(width=20,
-                  height=20,
-                  rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
-                          ['../notebooks/temp.npy']),
-                  number_of_agents=3)
-
-    """
-
-
-
-    # Example generate a random rail
-    # env = RailEnvRLLibWrapper(width=config['map_width'], height=config['map_height'],
-    #               rail_generator=complex_rail_generator(nr_start_goal=config["n_agents"], nr_extra=20, min_dist=12),
-    #               number_of_agents=config["n_agents"])
 
+    # Observation space and action space definitions
     obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
     act_space = gym.spaces.Discrete(4)
 
     # Dict with the different policies to train
     policy_graphs = {
-        config['policy_folder_name'].format(**locals()): (PPOPolicyGraph, obs_space, act_space, {})
+        config['policy_folder_name'].format(**locals()): (PolicyGraph, obs_space, act_space, {})
     }
 
     def policy_mapping_fn(agent_id):
         return config['policy_folder_name'].format(**locals())
 
-    agent_config = ppo.DEFAULT_CONFIG.copy()
-    agent_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": "my_prep"}
-    agent_config['multiagent'] = {"policy_graphs": policy_graphs,
+
+    # Trainer configuration
+    trainer_config = DEFAULT_CONFIG.copy()
+    trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": "my_prep"}
+    trainer_config['multiagent'] = {"policy_graphs": policy_graphs,
                                   "policy_mapping_fn": policy_mapping_fn,
                                   "policies_to_train": list(policy_graphs.keys())}
-    agent_config["horizon"] = config['horizon']
-
-    agent_config["num_workers"] = 0
-    agent_config["num_cpus_per_worker"] = 10
-    agent_config["num_gpus"] = 0.5
-    agent_config["num_gpus_per_worker"] = 0.5
-    agent_config["num_cpus_for_driver"] = 2
-    agent_config["num_envs_per_worker"] = 10
-    agent_config["env_config"] = env_config
-    agent_config["batch_mode"] = "complete_episodes"
-    agent_config['simple_optimizer'] = False
+    trainer_config["horizon"] = config['horizon']
+
+    trainer_config["num_workers"] = 0
+    trainer_config["num_cpus_per_worker"] = 10
+    trainer_config["num_gpus"] = 0.5
+    trainer_config["num_gpus_per_worker"] = 0.5
+    trainer_config["num_cpus_for_driver"] = 2
+    trainer_config["num_envs_per_worker"] = 10
+    trainer_config["env_config"] = env_config
+    trainer_config["batch_mode"] = "complete_episodes"
+    trainer_config['simple_optimizer'] = False
 
     def logger_creator(conf):
         """Creates a Unified logger with a default logdir prefix
         containing the agent name and the env id
         """
-        print("FOLDER", config['policy_folder_name'])
         logdir = config['policy_folder_name'].format(**locals())
         logdir = tempfile.mkdtemp(
             prefix=logdir, dir=config['local_dir'])
@@ -109,19 +93,18 @@ def train(config, reporter):
 
     logger = logger_creator
 
-    ppo_trainer = PPOTrainer(env=RailEnvRLLibWrapper, config=agent_config, logger_creator=logger)
+    trainer = Trainer(env=RailEnvRLLibWrapper, config=trainer_config, logger_creator=logger)
 
     for i in range(100000 + 2):
         print("== Iteration", i, "==")
 
-        print("-- PPO --")
-        print(pretty_print(ppo_trainer.train()))
+        print(pretty_print(trainer.train()))
 
         if i % config['save_every'] == 0:
-            checkpoint = ppo_trainer.save()
+            checkpoint = trainer.save()
             print("checkpoint saved at", checkpoint)
 
-        reporter(num_iterations_trained=ppo_trainer._iteration)
+        reporter(num_iterations_trained=trainer._iteration)
 
 
 @gin.configurable
@@ -151,6 +134,6 @@ def run_grid_search(name, num_iterations, n_agents, hidden_sizes, save_every,
 
 if __name__ == '__main__':
     gin.external_configurable(tune.grid_search)
-    dir = '/mount/SDC/flatland/baselines/grid_search_configs/n_agents_grid_search'
+    dir = '/mount/SDC/flatland/baselines/grid_search_configs/n_agents_grid_search'  # To Modify
     gin.parse_config_file(dir + '/config.gin')
     run_grid_search(local_dir=dir)
-- 
GitLab