diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index e901da5bf9331a4041d267e8692ab408624a1a11..28c78cb6332e9084634cd363df70355b29b38f64 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -37,11 +37,11 @@ env = RailEnv(width=50,
                                                    node_radius=4,  # Proximity of stations to city center
                                                    seed=0,  # Random seed
                                                    grid_mode=True,
-                                                   max_connection_points_per_side=2,
-                                                   max_nr_connection_directions=2
+                                                   max_inter_city_rails=2,
+                                                   tracks_in_city=4,
                                                    ),
               schedule_generator=sparse_schedule_generator(),
-              number_of_agents=5,
+              number_of_agents=10,
               stochastic_data=stochastic_data,  # Malfunction data generator
               obs_builder_object=GlobalObsForRailEnv())
 
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index cf163e909941ee44b4da0af71c24534a7d1535d4..e381733faeb1c83ef72a3e19f10f17dce4d02966 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -533,8 +533,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener
 
 
 def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2,
-                          grid_mode=False, max_connection_points_per_side=4,
-                          max_nr_connection_directions=2,
+                          grid_mode=False, max_inter_city_rails=4, tracks_in_city=4,
                           seed=0) -> RailGenerator:
     """
     This is a level generator which generates complex sparse rail configurations
@@ -578,11 +577,11 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2,
 
         # Set up connection points for all cities
         connection_points, connection_info = _generate_node_connection_points(node_positions, node_radius,
-                                                                              max_connection_points_per_side,
-                                                                              max_nr_connection_directions)
+                                                                              tracks_in_city)
 
         # Connect the cities through the connection points
-        _connect_cities(node_positions, connection_points, connection_info, city_cells, rail_trans, grid_map)
+        _connect_cities(node_positions, connection_points, connection_info, city_cells, max_inter_city_rails,
+                        rail_trans, grid_map)
 
         # Build inner cities
         _build_inner_cities(node_positions, connection_points, rail_trans, grid_map)
@@ -617,8 +616,8 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2,
             tries = 0
 
             while to_close:
-                x_tmp = node_radius + np.random.randint(height - 2 * node_radius - 1)
-                y_tmp = node_radius + np.random.randint(width - 2 * node_radius - 1)
+                x_tmp = node_radius + 1 + np.random.randint(height - 2 * node_radius - 1)
+                y_tmp = node_radius + 1 + np.random.randint(width - 2 * node_radius - 1)
                 to_close = False
 
                 # Check distance to nodes
@@ -644,8 +643,8 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2,
         nodes_ratio = height / width
         nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio)))
         nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row))
-        x_positions = np.linspace(node_radius, height - node_radius - 1, nodes_per_row, dtype=int)
-        y_positions = np.linspace(node_radius, width - node_radius - 1, nodes_per_col, dtype=int)
+        x_positions = np.linspace(node_radius + 1, height - node_radius - 2, nodes_per_row, dtype=int)
+        y_positions = np.linspace(node_radius + 1, width - node_radius - 2, nodes_per_col, dtype=int)
         node_positions = []
         city_cells = []
         for node_idx in range(nb_nodes):
@@ -655,13 +654,11 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2,
             city_cells.extend(_city_cells(node_positions[-1], node_radius))
         return node_positions, city_cells
 
-    def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2,
-                                         max_nr_connection_directions=2):
+    def _generate_node_connection_points(node_positions, node_size, tracks_in_city=2):
         connection_points = []
         connection_info = []
-        max_nr_connection_directions = np.clip(max_nr_connection_directions, 0, 4)
-        if max_nr_connection_points > 2 * node_size + 1:
-            max_nr_connection_points = 2 * node_size + 1
+        if tracks_in_city > 2 * node_size + 1:
+            tracks_in_city = 2 * node_size + 1
 
         for node_position in node_positions:
 
@@ -671,21 +668,17 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2,
                 neighb_dist.append(distance_on_rail(node_position, neighb_node))
             closest_neighb_idx = argsort(neighb_dist)
 
-            # Store the directions to these neighbours
+            # Store the directions to these neighbours and orient city to face closest neighbour
             connection_sides_idx = []
             idx = 1
-            while len(connection_sides_idx) < max_nr_connection_directions and idx < len(neighb_dist):
-                current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]])
-                if current_closest_direction not in connection_sides_idx:
-                    connection_sides_idx.append(current_closest_direction)
-                idx += 1
+            current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]])
+            connection_sides_idx.append(current_closest_direction)
+            connection_sides_idx.append((current_closest_direction + 2) % 4)
 
-            # set the number of connection points for each direction
+            # set the number of tracks within a city, at least 2 tracks per city
             connections_per_direction = np.zeros(4, dtype=int)
-
+            nr_of_connection_points = np.random.randint(2, tracks_in_city + 1)
             for idx in connection_sides_idx:
-                nr_of_connection_points = np.random.randint(1, max_nr_connection_points + 1)
-
                 connections_per_direction[idx] = nr_of_connection_points
             connection_points_coordinates = [[] for i in range(4)]
 
@@ -710,7 +703,8 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2,
             connection_info.append(connections_per_direction)
         return connection_points, connection_info
 
-    def _connect_cities(node_positions, connection_points, connection_info, city_cells, rail_trans, grid_map):
+    def _connect_cities(node_positions, connection_points, connection_info, city_cells, max_inter_city_rails,
+                        rail_trans, grid_map):
         """
         Function to connect the different cities through their connection points
         :param node_positions: Positions of city centers
@@ -723,6 +717,7 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2,
         boarder_connections = set()
         for current_node in np.arange(len(node_positions)):
             direction = 0
+            connected_to_city = []
             for nbr_connection_points in connection_info[current_node]:
                 if nbr_connection_points > 0:
                     neighb_idx = _closest_neigh_in_direction(current_node, direction, node_positions)
@@ -730,23 +725,35 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2,
                     direction += 1
                     continue
 
-                if neighb_idx is not None:
-                    connection_distances = []
-                    for tmp_out_connection_point in connection_points[current_node][direction]:
-                        # Find closest connection point
-                        min_connection_dist = np.inf
-                        all_neighb_connection_points = [item for sublist in connection_points[neighb_idx] for item in
-                                                        sublist]
-
-                        for tmp_in_connection_point in all_neighb_connection_points:
-                            tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point)
-                            if tmp_dist < min_connection_dist:
-                                min_connection_dist = tmp_dist
-                                neighb_connection_point = tmp_in_connection_point
-                        connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point,
-                                       city_cells)
-                        boarder_connections.add((tmp_out_connection_point, current_node))
-                        boarder_connections.add((neighb_connection_point, neighb_idx))
+                if neighb_idx is None or neighb_idx in connected_to_city:
+                    node_dist = []
+                    for av_node in node_positions:
+                        node_dist.append(distance_on_rail(node_positions[current_node], av_node))
+                    i = 1
+                    neighbours = np.argsort(node_dist)
+                    neighb_idx = neighbours[i]
+                    while neighb_idx in connected_to_city:
+                        i += 1
+                        neighb_idx = neighbours[i]
+
+                connected_to_city.append(neighb_idx)
+                number_of_out_rails = np.random.randint(1, max_inter_city_rails + 1)
+
+                for tmp_out_connection_point in connection_points[current_node][direction][:number_of_out_rails]:
+                    # Find closest connection point
+                    min_connection_dist = np.inf
+                    all_neighb_connection_points = [item for sublist in connection_points[neighb_idx] for item in
+                                                    sublist]
+
+                    for tmp_in_connection_point in all_neighb_connection_points:
+                        tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point)
+                        if tmp_dist < min_connection_dist:
+                            min_connection_dist = tmp_dist
+                            neighb_connection_point = tmp_in_connection_point
+                    connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point,
+                                   city_cells)
+                    boarder_connections.add((tmp_out_connection_point, current_node))
+                    boarder_connections.add((neighb_connection_point, neighb_idx))
                 direction += 1
         return boarder_connections