From 014f2d747e1e1d8d3e31bd33d0894b7ab79c36bd Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Sun, 18 Aug 2019 07:49:43 -0400
Subject: [PATCH] priority to connect to intersection instead of other city.
 minor bug fixes in node connecting algorithm

---
 flatland/envs/generators.py                      | 6 ++++++
 tests/test_flatland_env_sparse_rail_generator.py | 6 +++---
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 27508c22..b571a003 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -874,11 +874,16 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
             current_node = node_stack[0]
             delete_idx = np.where(available_nodes_full == current_node)
             available_nodes_full = np.delete(available_nodes_full, delete_idx, 0)
+
             if current_node < num_cities and len(available_intersections) > 0:
                 available_nodes = available_intersections
+                delete_idx = np.where(available_cities == current_node)
+
                 available_cities = np.delete(available_cities, delete_idx, 0)
             elif len(available_intersections) > 0:
                 available_nodes = available_cities
+                delete_idx = np.where(available_intersections == current_node)
+
                 available_intersections = np.delete(available_intersections, delete_idx, 0)
             else:
                 available_nodes = available_nodes_full
@@ -892,6 +897,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
             if len(available_nodes) >= num_neighb:
                 connected_neighb_idx = available_nodes[
                                        0:num_neighb]  # np.random.choice(available_nodes, num_neighb, replace=False)
+                print(current_node, "-->", connected_neighb_idx)
             else:
                 connected_neighb_idx = available_nodes
 
diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py
index b20754be..48d4c577 100644
--- a/tests/test_flatland_env_sparse_rail_generator.py
+++ b/tests/test_flatland_env_sparse_rail_generator.py
@@ -30,12 +30,12 @@ def test_sparse_rail_generator():
                                                        num_trainstations=10,  # Number of possible start/targets on map
                                                        min_node_dist=10,  # Minimal distance of nodes
                                                        node_radius=2,  # Proximity of stations to city center
-                                                       num_neighb=2,  # Number of connections to other cities
-                                                       seed=15,  # Random seed
+                                                       num_neighb=1,  # Number of connections to other cities
+                                                       seed=5,  # Random seed
                                                        ),
                   number_of_agents=0,
                   obs_builder_object=GlobalObsForRailEnv())
     # reset to initialize agents_static
     env_renderer = RenderTool(env, gl="PILSVG", )
     env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
-    time.sleep(5)
+    time.sleep(19)
-- 
GitLab