From 44635204283dd5a026d067d1168d50f8742ade8c Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Wed, 25 Sep 2019 14:27:26 -0400
Subject: [PATCH] fixed bug in randomly placed cities

---
 examples/flatland_2_0_example.py |  4 ++--
 flatland/envs/rail_generators.py | 31 ++++++++++++++++---------------
 2 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 339945f3..2334450c 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -35,11 +35,11 @@ env = RailEnv(width=50,
               rail_generator=sparse_rail_generator(num_cities=9,  # Number of cities in map (where train stations are)
                                                    num_intersections=0,  # Number of intersections (no start / target)
                                                    num_trainstations=15,  # Number of possible start/targets on map
-                                                   min_node_dist=3,  # Minimal distance of nodes
+                                                   min_node_dist=10,  # Minimal distance of nodes
                                                    node_radius=4,  # Proximity of stations to city center
                                                    num_neighb=2,  # Number of connections to other cities/intersections
                                                    seed=15,  # Random seed
-                                                   grid_mode=True,
+                                                   grid_mode=False,
                                                    enhance_intersection=False
                                                    ),
               schedule_generator=sparse_schedule_generator(),
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index f3d7b889..c0721f62 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -604,26 +604,30 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
         # Start at some node
         current_node = np.random.randint(len(available_nodes_full))
         node_stack = [current_node]
+        open_nodes = np.copy(available_nodes_full)
         allowed_connections = num_neighb
-        first_node = True
         i = 0
         boarder_connections = set()
-        while len(node_stack) > 0:
-            current_node = node_stack[0]
-            delete_idx = np.where(available_nodes_full == current_node)
-            available_nodes_full = np.delete(available_nodes_full, delete_idx, 0)
+        while len(open_nodes) > 0:
+            if len(node_stack) > 0:
+                current_node = node_stack[0]
+            else:
+                current_node = np.random.choice(open_nodes)
+                node_stack.append(current_node)
+            delete_idx = np.where(open_nodes == current_node)
+            open_nodes = np.delete(open_nodes, delete_idx, 0)
 
             # Priority city to intersection connections
             if current_node < _num_cities and len(available_intersections) > 0:
                 available_nodes = available_intersections
                 delete_idx = np.where(available_cities == current_node)
-                available_cities = np.delete(available_cities, delete_idx, 0)
+                # available_cities = np.delete(available_cities, delete_idx, 0)
 
             # Priority intersection to city connections
             elif current_node >= _num_cities and len(available_cities) > 0:
                 available_nodes = available_cities
                 delete_idx = np.where(available_intersections == current_node)
-                available_intersections = np.delete(available_intersections, delete_idx, 0)
+                # available_intersections = np.delete(available_intersections, delete_idx, 0)
 
             # If no options possible connect to whatever node is still available
             else:
@@ -637,18 +641,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
 
             # Set number of neighboring nodes
             if len(available_nodes) >= allowed_connections:
-                connected_neighb_idx = available_nodes[:allowed_connections]
+                connected_neighb_idx = available_nodes[1:allowed_connections + 1]
             else:
                 connected_neighb_idx = available_nodes
+
             print(current_node, connected_neighb_idx)
-            # Less connections for subsequent nodes
-            if first_node:
-                allowed_connections -= 1
-                first_node = False
 
             # Connect to the neighboring nodes
             for neighb in connected_neighb_idx:
-                if neighb not in node_stack:
+                if neighb not in node_stack and neighb in open_nodes:
                     node_stack.append(neighb)
 
                 dist_from_center = distance_on_rail(node_positions[current_node], node_positions[neighb])
@@ -824,8 +825,8 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
             tries = 0
 
             while to_close:
-                x_tmp = node_radius + np.random.randint(height - node_radius - 1)
-                y_tmp = node_radius + np.random.randint(width - node_radius - 1)
+                x_tmp = node_radius + np.random.randint(height - 2 * node_radius - 1)
+                y_tmp = node_radius + np.random.randint(width - 2 * node_radius - 1)
                 to_close = False
 
                 # Check distance to cities
-- 
GitLab