From 4928e9063688efc3569deb29c9d907d4ef46a57e Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Sun, 18 Aug 2019 08:20:12 -0400
Subject: [PATCH] updated number of node selection algorithm

---
 flatland/envs/generators.py                      | 12 +++++++-----
 tests/test_flatland_env_sparse_rail_generator.py | 10 +++++-----
 2 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 9d1bcfec..4463b1e3 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -869,7 +869,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
         available_intersections = np.arange(num_cities, num_cities + num_intersections)
         current_node = 0
         node_stack = [current_node]
-
+        allowed_connections = num_neighb
         while len(node_stack) > 0:
             current_node = node_stack[0]
             delete_idx = np.where(available_nodes_full == current_node)
@@ -883,7 +883,6 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
             elif len(available_intersections) > 0:
                 available_nodes = available_cities
                 delete_idx = np.where(available_intersections == current_node)
-
                 available_intersections = np.delete(available_intersections, delete_idx, 0)
             else:
                 available_nodes = available_nodes_full
@@ -894,12 +893,15 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
             available_nodes = available_nodes[np.argsort(node_dist)]
 
             # Set number of neighboring nodes
-            if len(available_nodes) >= num_neighb:
-                connected_neighb_idx = available_nodes[
-                                       0:num_neighb]  # np.random.choice(available_nodes, num_neighb, replace=False)
+
+            print(current_node, allowed_connections)
+            if len(available_nodes) >= allowed_connections:
+                connected_neighb_idx = available_nodes[:allowed_connections]
             else:
                 connected_neighb_idx = available_nodes
 
+            if current_node == 0:
+                allowed_connections -= 1
             # Connect to the neighboring nodes
             for neighb in connected_neighb_idx:
                 if neighb not in node_stack:
diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py
index 48d4c577..f49893ae 100644
--- a/tests/test_flatland_env_sparse_rail_generator.py
+++ b/tests/test_flatland_env_sparse_rail_generator.py
@@ -25,12 +25,12 @@ def test_realistic_rail_generator():
 def test_sparse_rail_generator():
     env = RailEnv(width=50,
                   height=50,
-                  rail_generator=sparse_rail_generator(num_cities=5,  # Number of cities in map
-                                                       num_intersections=2,  # Number of interesections in map
-                                                       num_trainstations=10,  # Number of possible start/targets on map
+                  rail_generator=sparse_rail_generator(num_cities=2,  # Number of cities in map
+                                                       num_intersections=3,  # Number of interesections in map
+                                                       num_trainstations=5,  # Number of possible start/targets on map
                                                        min_node_dist=10,  # Minimal distance of nodes
                                                        node_radius=2,  # Proximity of stations to city center
-                                                       num_neighb=1,  # Number of connections to other cities
+                                                       num_neighb=3,  # Number of connections to other cities
                                                        seed=5,  # Random seed
                                                        ),
                   number_of_agents=0,
@@ -38,4 +38,4 @@ def test_sparse_rail_generator():
     # reset to initialize agents_static
     env_renderer = RenderTool(env, gl="PILSVG", )
     env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
-    time.sleep(19)
+    time.sleep(10)
-- 
GitLab