diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index b3c24da79a3f16480405dfe5b7bf94902d513a0e..ceedc90a95f8a434be67a7533873d3eb00154537 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -21,14 +21,15 @@ stochastic_data = {'prop_malfunction': 0.5,  # Percentage of defective agents
 TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
 env = RailEnv(width=20,
               height=20,
-              rail_generator=sparse_rail_generator(num_cities=5,  # Number of cities in map (where train stations are)
-                                                   num_intersections=4,  # Number of interesections (no start / target)
+              rail_generator=sparse_rail_generator(num_cities=2,  # Number of cities in map (where train stations are)
+                                                   num_intersections=1,  # Number of interesections (no start / target)
                                                    num_trainstations=15,  # Number of possible start/targets on map
                                                    min_node_dist=3,  # Minimal distance of nodes
                                                    node_radius=3,  # Proximity of stations to city center
-                                                   num_neighb=3,  # Number of connections to other cities/intersections
+                                                   num_neighb=2,  # Number of connections to other cities/intersections
                                                    seed=15,  # Random seed
-                                                   realistic_mode=True
+                                                   realistic_mode=True,
+                                                   enhance_intersection=True
                                                    ),
               number_of_agents=5,
               stochastic_data=stochastic_data,  # Malfunction generator data
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 087c6299966f5d1e04e6d21809afbd73999be6c1..e3b59e4fccba235b008eb8c576f5bdb334061248 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -975,7 +975,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
 
         if num_agents > num_trainstations:
             num_agents = num_trainstations
-            warnings.warn("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
+            warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents")
 
         rail_trans = RailEnvTransitions()
         grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
@@ -998,6 +998,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
             x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int)
             y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int)
 
+
         for node_idx in range(num_cities + num_intersections):
             to_close = True
             tries = 0
@@ -1043,18 +1044,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
         available_cities = np.arange(num_cities)
         available_intersections = np.arange(num_cities, num_cities + num_intersections)
 
-        # Keep track of number of incoming connection
-        incoming_connections = np.zeros(num_intersections + num_cities)
-
         # Start at some node
         current_node = np.random.randint(len(available_nodes_full))
         node_stack = [current_node]
         allowed_connections = num_neighb
+        first_node = True
         while len(node_stack) > 0:
             current_node = node_stack[0]
             delete_idx = np.where(available_nodes_full == current_node)
             available_nodes_full = np.delete(available_nodes_full, delete_idx, 0)
-            filled_nodes = np.where(incoming_connections >= num_neighb)
             # Priority city to intersection connections
             if current_node < num_cities and len(available_intersections) > 0:
                 available_nodes = available_intersections
@@ -1083,16 +1081,16 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
             else:
                 connected_neighb_idx = available_nodes
 
-            if current_node == 0:
+            # Less connections for subsequent nodes
+            if first_node:
                 allowed_connections -= 1
+                first_node = False
 
             # Connect to the neighboring nodes
             for neighb in connected_neighb_idx:
                 if neighb not in node_stack:
                     node_stack.append(neighb)
                 connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb])
-                incoming_connections[neighb] += 1
-                incoming_connections[current_node] += 1
             node_stack.pop(0)
 
         # Place train stations close to the node
@@ -1133,6 +1131,12 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
                 if len(connection) == 0:
                     train_stations[trainstation_node].pop(-1)
 
+        # Adjust the number of agents if you could not build enough trainstations
+        built_num_trainstation = len(train_stations)
+        if num_agents > built_num_trainstation:
+            num_agents = built_num_trainstation
+            warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents")
+
         # Place passing lanes at intersections
         # We currently place them uniformly distirbuted among all cities
         if enhance_intersection: