diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 2aa148999d2398ac6ab0696388f06bf4e93b8fab..27508c222868448d53755012b3ef2ce0da501232 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -838,7 +838,10 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
         np.random.seed(seed + num_resets)
 
         # Generate a set of nodes for the sparse network
+        # Try to connect cities to nodes first
         node_positions = []
+        city_positions = []
+        intersection_positions = []
         for node_idx in range(num_cities + num_intersections):
             to_close = True
             tries = 0
@@ -851,21 +854,34 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
                         to_close = True
                 if not to_close:
                     node_positions.append((x_tmp, y_tmp))
+                    if node_idx < num_cities:
+                        city_positions.append((x_tmp, y_tmp))
+                    else:
+                        intersection_positions.append((x_tmp, y_tmp))
                 tries += 1
                 if tries > 100:
                     warnings.warn("Could not set nodes, please change initial parameters!!!!")
                     break
 
         # Chose node connection
-        available_nodes = np.arange(num_cities + num_intersections)
+        available_nodes_full = np.arange(num_cities + num_intersections)
+        available_cities = np.arange(num_cities)
+        available_intersections = np.arange(num_cities, num_cities + num_intersections)
         current_node = 0
         node_stack = [current_node]
 
         while len(node_stack) > 0:
             current_node = node_stack[0]
-            delete_idx = np.where(available_nodes == current_node)
-            available_nodes = np.delete(available_nodes, delete_idx, 0)
-
+            delete_idx = np.where(available_nodes_full == current_node)
+            available_nodes_full = np.delete(available_nodes_full, delete_idx, 0)
+            if current_node < num_cities and len(available_intersections) > 0:
+                available_nodes = available_intersections
+                available_cities = np.delete(available_cities, delete_idx, 0)
+            elif len(available_intersections) > 0:
+                available_nodes = available_cities
+                available_intersections = np.delete(available_intersections, delete_idx, 0)
+            else:
+                available_nodes = available_nodes_full
             # Sort available neighbors according to their distance.
             node_dist = []
             for av_node in available_nodes:
@@ -885,30 +901,36 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
                     node_stack.append(neighb)
                 connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb])
             node_stack.pop(0)
+
         # Place train stations close to the node
         # We currently place them uniformly distirbuted among all cities
-        train_stations = [[] for i in range(num_cities)]
-
-        for station in range(num_trainstations):
-            trainstation_node = int(station / num_trainstations * num_cities)
-
-            station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), 0,
-                                height - 1)
-            station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius), 0,
-                                width - 1)
-            while (station_x, station_y) in train_stations or (station_x, station_y) == node_positions[
-                trainstation_node] or \
-                rail_array[(station_x, station_y)] != 0:
+        if num_cities > 1:
+            train_stations = [[] for i in range(num_cities)]
+
+            for station in range(num_trainstations):
+                trainstation_node = int(station / num_trainstations * num_cities)
+
                 station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius),
                                     0,
                                     height - 1)
                 station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius),
                                     0,
                                     width - 1)
-            train_stations[trainstation_node].append((station_x, station_y))
-
-            # Connect train station to the correct node
-            connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], (station_x, station_y))
+                while (station_x, station_y) in train_stations or (station_x, station_y) == node_positions[
+                    trainstation_node] or \
+                    rail_array[(station_x, station_y)] != 0:
+                    station_x = np.clip(
+                        node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius),
+                        0,
+                        height - 1)
+                    station_y = np.clip(
+                        node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius),
+                        0,
+                        width - 1)
+                train_stations[trainstation_node].append((station_x, station_y))
+
+                # Connect train station to the correct node
+                connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], (station_x, station_y))
 
         # Fix all nodes with illegal transition maps
         for current_node in node_positions:
diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py
index c4430e6e4d962d9011e2c973e8f342f09f647c38..b20754be9fdc76ee57c25da4fe211ecdceb55d04 100644
--- a/tests/test_flatland_env_sparse_rail_generator.py
+++ b/tests/test_flatland_env_sparse_rail_generator.py
@@ -1,3 +1,5 @@
+import time
+
 import numpy as np
 
 from flatland.envs.generators import sparse_rail_generator, realistic_rail_generator
@@ -23,17 +25,17 @@ def test_realistic_rail_generator():
 def test_sparse_rail_generator():
     env = RailEnv(width=50,
                   height=50,
-                  rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
-                                                       num_intersections=3,  # Number of interesections in map
+                  rail_generator=sparse_rail_generator(num_cities=5,  # Number of cities in map
+                                                       num_intersections=2,  # Number of interesections in map
                                                        num_trainstations=10,  # Number of possible start/targets on map
                                                        min_node_dist=10,  # Minimal distance of nodes
                                                        node_radius=2,  # Proximity of stations to city center
-                                                       num_neighb=4,  # Number of connections to other cities
+                                                       num_neighb=2,  # Number of connections to other cities
                                                        seed=15,  # Random seed
                                                        ),
-                  number_of_agents=1,
+                  number_of_agents=0,
                   obs_builder_object=GlobalObsForRailEnv())
     # reset to initialize agents_static
     env_renderer = RenderTool(env, gl="PILSVG", )
     env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
-
+    time.sleep(5)