From 36d8f9e02ab4d045f7d6d4308c15ceffb2e675bb Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Mon, 19 Aug 2019 08:36:54 -0400
Subject: [PATCH] added realistic_mode for less random levels

---
 flatland/envs/generators.py                   | 58 +++++++++++++------
 ...test_flatland_env_sparse_rail_generator.py | 14 ++---
 2 files changed, 46 insertions(+), 26 deletions(-)

diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 2826ee51..dad07a60 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -813,8 +813,7 @@ def realistic_rail_generator(nr_start_goal=1, seed=0):
 
 
 def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstations=2, min_node_dist=20, node_radius=2,
-                          num_neighb=4,
-                          seed=0):
+                          num_neighb=4, realistic_mode=False, seed=0):
     '''
 
     :param nr_train_stations:
@@ -843,26 +842,46 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
         city_positions = []
         intersection_positions = []
 
+        # Evenly distribute cities and intersections
+        if realistic_mode:
+            tot_num_node = num_intersections + num_cities
+            nodes_ratio = height / width
+            nodes_per_row = int(np.ceil(np.sqrt(tot_num_node * nodes_ratio)))
+            nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row))
+            x_positions = np.linspace(2, height - 2, nodes_per_row, dtype=int)
+            y_positions = np.linspace(2, width - 2, nodes_per_col, dtype=int)
         for node_idx in range(num_cities + num_intersections):
             to_close = True
             tries = 0
-            while to_close:
-                x_tmp = 1 + np.random.randint(height - 2)
-                y_tmp = 1 + np.random.randint(width - 2)
-                to_close = False
-                for node_pos in node_positions:
-                    if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
-                        to_close = True
-                if not to_close:
-                    node_positions.append((x_tmp, y_tmp))
-                    if node_idx < num_cities:
-                        city_positions.append((x_tmp, y_tmp))
-                    else:
-                        intersection_positions.append((x_tmp, y_tmp))
-                tries += 1
-                if tries > 100:
-                    warnings.warn("Could not set nodes, please change initial parameters!!!!")
-                    break
+            if not realistic_mode:
+                while to_close:
+                    x_tmp = 1 + np.random.randint(height - 2)
+                    y_tmp = 1 + np.random.randint(width - 2)
+                    to_close = False
+                    for node_pos in node_positions:
+                        if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
+                            to_close = True
+                    if not to_close:
+                        node_positions.append((x_tmp, y_tmp))
+                        if node_idx < num_cities:
+                            city_positions.append((x_tmp, y_tmp))
+                        else:
+                            intersection_positions.append((x_tmp, y_tmp))
+                    tries += 1
+                    if tries > 100:
+                        warnings.warn("Could not set nodes, please change initial parameters!!!!")
+                        break
+            else:
+                x_tmp = x_positions[node_idx % nodes_per_row]
+                y_tmp = y_positions[node_idx // nodes_per_row]
+                if len(city_positions) < num_cities and (node_idx % (tot_num_node // num_cities)) == 0:
+                    city_positions.append((x_tmp, y_tmp))
+                else:
+                    intersection_positions.append((x_tmp, y_tmp))
+
+        if realistic_mode:
+            node_positions = city_positions + intersection_positions
+
         # Chose node connection
         available_nodes_full = np.arange(num_cities + num_intersections)
         available_cities = np.arange(num_cities)
@@ -886,6 +905,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
                 available_intersections = np.delete(available_intersections, delete_idx, 0)
             else:
                 available_nodes = available_nodes_full
+
             # Sort available neighbors according to their distance.
             node_dist = []
             for av_node in available_nodes:
diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py
index 92744080..74513b72 100644
--- a/tests/test_flatland_env_sparse_rail_generator.py
+++ b/tests/test_flatland_env_sparse_rail_generator.py
@@ -23,17 +23,17 @@ def test_realistic_rail_generator():
 
 
 def test_sparse_rail_generator():
-    env = RailEnv(width=20,
-                  height=20,
-                  rail_generator=sparse_rail_generator(num_cities=5,  # Number of cities in map
-                                                       num_intersections=2,  # Number of interesections in map
-                                                       num_trainstations=20,  # Number of possible start/targets on map
+    env = RailEnv(width=50,
+                  height=50,
+                  rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
+                                                       num_intersections=10,  # Number of interesections in map
+                                                       num_trainstations=50,  # Number of possible start/targets on map
                                                        min_node_dist=6,  # Minimal distance of nodes
                                                        node_radius=3,  # Proximity of stations to city center
-                                                       num_neighb=2,  # Number of connections to other cities
+                                                       num_neighb=4,  # Number of connections to other cities
                                                        seed=5,  # Random seed
                                                        ),
-                  number_of_agents=1,
+                  number_of_agents=45,
                   obs_builder_object=GlobalObsForRailEnv())
     # reset to initialize agents_static
     env_renderer = RenderTool(env, gl="PILSVG", )
-- 
GitLab