diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 13b7335e66492744e1b28a159ff5f51442860132..feb37909740f546106b43e0bb517f42c81499dd4 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -30,15 +30,14 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
 
 env = RailEnv(width=50,
               height=50,
-              rail_generator=sparse_rail_generator(num_cities=12,  # Number of cities in map (where train stations are)
-                                                   node_radius=4,  # Proximity of stations to city center
+              rail_generator=sparse_rail_generator(num_cities=9,  # Number of cities in map (where train stations are)
                                                    seed=0,  # Random seed
                                                    grid_mode=False,
                                                    max_inter_city_rails=2,
-                                                   max_tracks_in_city=5,
+                                                   max_tracks_in_city=8,
                                                    ),
               schedule_generator=sparse_schedule_generator(),
-              number_of_agents=50,
+              number_of_agents=10,
               stochastic_data=stochastic_data,  # Malfunction data generator
               obs_builder_object=GlobalObsForRailEnv())
 
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index a86e95d2c486486341294dece158c7789218aea8..a9e9796fb52e0c60fcff98c5c39ae26e4e76eb20 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -532,23 +532,16 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener
     return generator
 
 
-def sparse_rail_generator(num_cities=5, node_radius=2,
-                          grid_mode=False, max_inter_city_rails=4, max_tracks_in_city=4,
+def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, max_tracks_in_city=4,
                           seed=0) -> RailGenerator:
     """
-    This is a level generator which generates complex sparse rail configurations
-
-    :param num_cities: Number of city node (can hold trainstations)
-    :type num_cities: int
-    :param num_intersections: Number of intersection that city nodes can connect to
-    :param num_trainstations: Total number of trainstations in env
-    :param min_node_dist: Minimal distance between nodes
-    :param node_radius: Proximity of trainstations to center of city node
-    :param num_neighb: Number of neighbouring nodes each node connects to
-    :param grid_mode: True -> NOdes evenly distirbuted in env, False-> Random distribution of nodes
-    :param enhance_intersection: True -> Extra rail elements added at intersections
-    :param seed: Random Seed
-    :return: numpy.ndarray of type numpy.uint16 -- The matrix with the correct 16-bit bitmaps for each cell.
+    Generates railway networks with cities and inner city rails
+    :param num_cities: Number of city centers in the map
+    :param grid_mode: Arange cities in a grid or randomly
+    :param max_inter_city_rails: Maximum number of connecting rails going out from a city
+    :param max_tracks_in_city: maximum number of internal rails
+    :param seed: Random seed to initiate rail
+    :return: generator
     """
 
     def generator(width, height, num_agents, num_resets=0) -> RailGeneratorProduct:
@@ -558,6 +551,7 @@ def sparse_rail_generator(num_cities=5, node_radius=2,
         rail_array = grid_map.grid
         rail_array.fill(0)
         np.random.seed(seed + num_resets)
+        node_radius = int(max_tracks_in_city / 2) + 1
         max_inter_city_rails_allowed = max_inter_city_rails
         if max_inter_city_rails_allowed > max_tracks_in_city:
             max_inter_city_rails_allowed = max_tracks_in_city
@@ -570,9 +564,9 @@ def sparse_rail_generator(num_cities=5, node_radius=2,
         node_positions: List[Any] = None
         nb_nodes = num_cities
         if grid_mode:
-            node_positions, city_cells = _generate_node_positions_grid_mode(nb_nodes, height, width)
+            node_positions, city_cells = _generate_node_positions_grid_mode(nb_nodes, node_radius, height, width)
         else:
-            node_positions, city_cells = _generate_random_node_positions(nb_nodes, height, width)
+            node_positions, city_cells = _generate_random_node_positions(nb_nodes, node_radius, height, width)
 
         # reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode
         nb_nodes = len(node_positions)
@@ -587,11 +581,13 @@ def sparse_rail_generator(num_cities=5, node_radius=2,
 
         # Build inner cities
         through_tracks = _build_inner_cities(node_positions, inner_connection_points, outer_connection_points,
+                                             node_radius,
                                              rail_trans,
                                              grid_map)
 
         # Populate cities
-        train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, through_tracks, grid_map)
+        train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, through_tracks,
+                                                                             node_radius, grid_map)
 
         # Adjust the number of agents if you could not build enough trainstations
         if num_agents > built_num_trainstation:
@@ -610,7 +606,7 @@ def sparse_rail_generator(num_cities=5, node_radius=2,
             'train_stations': train_stations
         }}
 
-    def _generate_random_node_positions(nb_nodes, height, width):
+    def _generate_random_node_positions(nb_nodes, node_radius, height, width):
 
         node_positions = []
         city_cells = []
@@ -642,7 +638,7 @@ def sparse_rail_generator(num_cities=5, node_radius=2,
 
         return node_positions, city_cells
 
-    def _generate_node_positions_grid_mode(nb_nodes, height, width):
+    def _generate_node_positions_grid_mode(nb_nodes, node_radius, height, width):
         nodes_ratio = height / width
         nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio)))
         nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row))
@@ -661,9 +657,6 @@ def sparse_rail_generator(num_cities=5, node_radius=2,
         inner_connection_points = []
         outer_connection_points = []
         connection_info = []
-        if tracks_in_city > 2 * node_size - 1:
-            tracks_in_city = 2 * node_size - 1
-
         for node_position in node_positions:
 
             # Chose the directions where close cities are situated
@@ -760,7 +753,8 @@ def sparse_rail_generator(num_cities=5, node_radius=2,
                 direction += 1
         return
 
-    def _build_inner_cities(node_positions, inner_connection_points, outer_connection_points, rail_trans, grid_map):
+    def _build_inner_cities(node_positions, inner_connection_points, outer_connection_points, node_radius, rail_trans,
+                            grid_map):
         """
         Builds inner city tracks. This current version connects all incoming connections to all outgoing connections
         :param node_positions: Positions of the cities
@@ -790,7 +784,7 @@ def sparse_rail_generator(num_cities=5, node_radius=2,
 
         return through_path_cells
 
-    def _set_trainstation_positions(node_positions, through_tracks, grid_map):
+    def _set_trainstation_positions(node_positions, through_tracks, node_radius, grid_map):
         """
 
         :param node_positions:
@@ -838,10 +832,11 @@ def sparse_rail_generator(num_cities=5, node_radius=2,
             avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0]
             avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0]
             # Set probability to choose start and stop from trainstations
-            sum_start = sum(node_available_start)
-            sum_target = sum(node_available_target)
-            p_avail_start = [float(i) / sum_start for i in node_available_start]
-            p_avail_target = [float(i) / sum_target for i in node_available_target]
+            sum_start = sum(np.array(node_available_start)[avail_start_nodes])
+            sum_target = sum(np.array(node_available_target)[avail_target_nodes])
+            p_avail_start = [float(i) / sum_start for i in np.array(node_available_start)[avail_start_nodes]]
+            p_avail_target = [float(i) / sum_target for i in np.array(node_available_target)[avail_target_nodes]]
+
             start_node = np.random.choice(avail_start_nodes, p=p_avail_start)
             target_node = np.random.choice(avail_target_nodes, p=p_avail_target)
             tries = 0