diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 898f6a9906e998e70ac41f0833e57a4df2fbcac1..ca9346fc71a8d5485d7b71098dc91558d09bd929 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -32,16 +32,16 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
 
 env = RailEnv(width=50,
               height=50,
-              rail_generator=sparse_rail_generator(num_cities=5,  # Number of cities in map (where train stations are)
-                                                   num_intersections=4,  # Number of intersections (no start / target)
+              rail_generator=sparse_rail_generator(num_cities=9,  # Number of cities in map (where train stations are)
                                                    num_trainstations=100,  # Number of possible start/targets on map
                                                    min_node_dist=10,  # Minimal distance of nodes
                                                    node_radius=4,  # Proximity of stations to city center
                                                    num_neighb=2,  # Number of connections to other cities/intersections
                                                    seed=15,  # Random seed
                                                    grid_mode=True,
-                                                   nr_inter_connections=2,
-                                                   max_nr_connection_points=12
+                                                   nr_parallel_tracks=2,
+                                                   connectin_points_per_side=3,
+                                                   max_nr_connection_directions=2,
                                                    ),
               schedule_generator=sparse_schedule_generator(),
               number_of_agents=50,
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 147b1bd14a5943f87f4d78f37e9ea7e1722c9087..117862a71593f8483928323574d8977898cfdc9c 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -528,8 +528,9 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener
     return generator
 
 
-def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2, min_node_dist=20, node_radius=2,
-                          num_neighb=3, nr_inter_connections=2, grid_mode=False, max_nr_connection_points=4,
+def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, node_radius=2,
+                          num_neighb=3, nr_parallel_tracks=2, grid_mode=False, connectin_points_per_side=4,
+                          max_nr_connection_directions=2,
                           seed=0) -> RailGenerator:
     """
     This is a level generator which generates complex sparse rail configurations
@@ -566,7 +567,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
 
         # Evenly distribute cities and intersections
         node_positions: List[Any] = None
-        nb_nodes = num_cities + num_intersections
+        nb_nodes = num_cities
         if grid_mode:
             nodes_ratio = height / width
             nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio)))
@@ -591,14 +592,14 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
         # reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode
         nb_nodes = len(node_positions)
         _num_cities = len(city_positions)
-        _num_intersections = len(intersection_positions)
 
         # Chose node connection
         # Set up list of available nodes to connect to
         available_nodes = np.arange(nb_nodes)
 
         # Set up connection points for all cities
-        connection_points = _generate_node_connection_points(node_positions, node_radius, max_nr_connection_points)
+        connection_points = _generate_node_connection_points(node_positions, node_radius, connectin_points_per_side,
+                                                             max_nr_connection_directions)
 
         # Start at some node
         current_node = np.random.randint(len(available_nodes))
@@ -639,8 +640,8 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
                     tmp_dist_to_node = distance_on_rail(tmp_out_connection_point, node_positions[neighb])
                     connection_distances.append(tmp_dist_to_node)
                 possible_connection_points = argsort(connection_distances)
-                for sort_idx in possible_connection_points[:nr_inter_connections]:
-                    # Find closes connection point
+                for sort_idx in possible_connection_points[:nr_parallel_tracks]:
+                    # Find closest connection point
                     tmp_out_connection_point = connection_points[current_node][sort_idx]
                     min_connection_dist = np.inf
                     for tmp_in_connection_point in connection_points[neighb]:
@@ -705,7 +706,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
                                                connection_points[trainstation_node][corner_node_idx],
                                                (station_x, station_y))
                     if len(connection) != 0:
-
                         if (connection_points[trainstation_node][corner_node_idx],
                             trainstation_node) in boarder_connections:
                             boarder_connections.remove(
@@ -748,6 +748,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
             for tbd in to_be_deleted:
                 boarder_connections.remove(tbd)
             print(boarder_connections)
+
         # Fix all nodes with illegal transition maps
         empty_to_fix = []
         rails_to_fix = []
@@ -866,24 +867,28 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
         node_positions = city_positions + intersection_positions
         return node_positions
 
-    def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2):
+    def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2,
+                                         max_nr_connection_directions=2):
         connection_points = []
         for node_position in node_positions:
-            n_connection_points = max_nr_connection_points  # np.random.randint(1, max_nr_connection_points)
-            connection_per_direction = n_connection_points // 4
-            connection_point_vector = [connection_per_direction, connection_per_direction, connection_per_direction,
-                                       n_connection_points - 3 * connection_per_direction]
+            connection_sides_idx = np.sort(
+                np.random.choice(np.arange(4), size=max_nr_connection_directions, replace=False))
+            connections_per_direction = np.zeros(4, dtype=int)
+
+            # set the number of connection points for each direction
+            for idx in connection_sides_idx:
+                connections_per_direction[idx] = max_nr_connection_points
             connection_points_coordinates = []
             random_connection_slots = False
             for direction in range(4):
                 if random_connection_slots:
                     connection_slots = np.random.choice(np.arange(-node_size, node_size),
-                                                        size=connection_point_vector[direction],
+                                                        size=connections_per_direction[direction],
                                                         replace=False)
                 else:
-                    connection_slots = np.arange(connection_point_vector[direction]) - int(
-                        connection_point_vector[direction] / 2)
-                for connection_idx in range(connection_point_vector[direction]):
+                    connection_slots = np.arange(connections_per_direction[direction]) - int(
+                        connections_per_direction[direction] / 2)
+                for connection_idx in range(connections_per_direction[direction]):
                     if direction == 0:
                         connection_points_coordinates.append(
                             (node_position[0] - node_size, node_position[1] + connection_slots[connection_idx]))