From c4df1ca08118e27704ba3dca690c35a1c3b77598 Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume.mollard2@gmail.com>
Date: Tue, 14 May 2019 14:06:34 +0200
Subject: [PATCH] up to date environment initialization

---
 .../n_agents_grid_search/config.gin           |  6 +--
 grid_search_train.py                          | 47 ++++++++++++++-----
 train.py                                      | 42 +++++++++++------
 3 files changed, 65 insertions(+), 30 deletions(-)

diff --git a/grid_search_configs/n_agents_grid_search/config.gin b/grid_search_configs/n_agents_grid_search/config.gin
index 71035d4..9830838 100644
--- a/grid_search_configs/n_agents_grid_search/config.gin
+++ b/grid_search_configs/n_agents_grid_search/config.gin
@@ -3,9 +3,9 @@ run_grid_search.num_iterations = 1002
 run_grid_search.save_every = 200
 run_grid_search.hidden_sizes = [32, 32]
 
-run_grid_search.map_width = 15
-run_grid_search.map_height = 15
-run_grid_search.n_agents = {"grid_search": [1, 2, 3, 4]}
+run_grid_search.map_width = 50
+run_grid_search.map_height = 50
+run_grid_search.n_agents = {"grid_search": [2, 5, 10, 20]}
 
 run_grid_search.horizon = 50
 
diff --git a/grid_search_train.py b/grid_search_train.py
index 919c012..0689ec0 100644
--- a/grid_search_train.py
+++ b/grid_search_train.py
@@ -40,23 +40,44 @@ def train(config, reporter):
 
     env_name = f"rail_env_{config['n_agents']}"  # To modify if different environments configs are explored.
 
-    # Example generate a rail given a manual specification,
-    # a map of tuples (cell_type, rotation)
-    transition_probability = [0.5,  # empty cell - Case 0
-                              1.0,  # Case 1 - straight
-                              1.0,  # Case 2 - simple switch
-                              0.3,  # Case 3 - diamond drossing
-                              0.5,  # Case 4 - single slip
-                              0.5,  # Case 5 - double slip
-                              0.2,  # Case 6 - symmetrical
-                              0.0]  # Case 7 - dead end
+    transition_probability = [15,  # empty cell - Case 0
+                              5,  # Case 1 - straight
+                              5,  # Case 2 - simple switch
+                              1,  # Case 3 - diamond crossing
+                              1,  # Case 4 - single slip
+                              1,  # Case 5 - double slip
+                              1,  # Case 6 - symmetrical
+                              0,  # Case 7 - dead end
+                              1,  # Case 1b (8)  - simple turn right
+                              1,  # Case 1c (9)  - simple turn left
+                              1]  # Case 2b (10) - simple switch mirrored
+
+    # Example generate a random rail
+    """
+    env = RailEnv(width=10,
+                  height=10,
+                  rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
+                  number_of_agents=1)
+    """
+    env = RailEnv(width=config['map_width'],
+                  height=config['map_height'],
+                  rail_generator=complex_rail_generator(nr_start_goal=50, min_dist=5, max_dist=99999, seed=0),
+                  number_of_agents=config['n_agents'])
+    """
+    env = RailEnv(width=20,
+                  height=20,
+                  rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
+                          ['../notebooks/temp.npy']),
+                  number_of_agents=3)
+
+    """
 
 
 
     # Example generate a random rail
-    env = RailEnvRLLibWrapper(width=config['map_width'], height=config['map_height'],
-                  rail_generator=complex_rail_generator(nr_start_goal=config["n_agents"], nr_extra=20, min_dist=12),
-                  number_of_agents=config["n_agents"])
+    # env = RailEnvRLLibWrapper(width=config['map_width'], height=config['map_height'],
+    #               rail_generator=complex_rail_generator(nr_start_goal=config["n_agents"], nr_extra=20, min_dist=12),
+    #               number_of_agents=config["n_agents"])
 
     register_env(env_name, lambda _: env)
 
diff --git a/train.py b/train.py
index 4e887c8..71f214c 100644
--- a/train.py
+++ b/train.py
@@ -45,23 +45,37 @@ def train(config):
     random.seed(1)
     np.random.seed(1)
 
-    # Example generate a rail given a manual specification,
-    # a map of tuples (cell_type, rotation)
-    transition_probability = [0.5,  # empty cell - Case 0
-                              1.0,  # Case 1 - straight
-                              1.0,  # Case 2 - simple switch
-                              0.3,  # Case 3 - diamond drossing
-                              0.5,  # Case 4 - single slip
-                              0.5,  # Case 5 - double slip
-                              0.2,  # Case 6 - symmetrical
-                              0.0]  # Case 7 - dead end
-
-
+    transition_probability = [15,  # empty cell - Case 0
+                              5,  # Case 1 - straight
+                              5,  # Case 2 - simple switch
+                              1,  # Case 3 - diamond crossing
+                              1,  # Case 4 - single slip
+                              1,  # Case 5 - double slip
+                              1,  # Case 6 - symmetrical
+                              0,  # Case 7 - dead end
+                              1,  # Case 1b (8)  - simple turn right
+                              1,  # Case 1c (9)  - simple turn left
+                              1]  # Case 2b (10) - simple switch mirrored
 
     # Example generate a random rail
-    env = RailEnvRLLibWrapper(width=15, height=15,
-                  rail_generator=complex_rail_generator(nr_start_goal=1, nr_extra=20, min_dist=12),
+    """
+    env = RailEnv(width=10,
+                  height=10,
+                  rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
                   number_of_agents=1)
+    """
+    env = RailEnv(width=50,
+                  height=50,
+                  rail_generator=complex_rail_generator(nr_start_goal=50, min_dist=5, max_dist=99999, seed=0),
+                  number_of_agents=5)
+    """
+    env = RailEnv(width=20,
+                  height=20,
+                  rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
+                          ['../notebooks/temp.npy']),
+                  number_of_agents=3)
+
+    """
 
     register_env("railenv", lambda _: env)
     # if config['render']:
-- 
GitLab