diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 0ec268dc1d3910967541ba0bcc357c8cc4638ec4..ada43ce4e25353cc8e6e639b5a3d1ce6b00c4953 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -582,7 +582,8 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
 
         # Connect the cities through the connection points
         city_connection_time = time.time()
-        _connect_cities(node_positions, outer_connection_points, connection_info, city_cells, rail_trans, grid_map)
+        inter_city_lines = _connect_cities(node_positions, outer_connection_points, connection_info, city_cells,
+                                           rail_trans, grid_map)
         print("City connection time", time.time() - city_connection_time)
         # Build inner cities
         city_build_time = time.time()
@@ -604,7 +605,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
 
         # Fix all transition elements
         grid_fix_time = time.time()
-        _fix_transitions(city_cells, grid_map)
+        _fix_transitions(city_cells, inter_city_lines, grid_map)
         print("Grid fix time", time.time() - grid_fix_time)
 
         # Generate start target pairs
@@ -728,6 +729,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
         :param grid_map: Grid map
         :return:
         """
+        all_paths = []
         for current_node in np.arange(len(node_positions)):
             direction = 0
             connected_to_city = []
@@ -758,11 +760,12 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
                         if tmp_dist < min_connection_dist:
                             min_connection_dist = tmp_dist
                             neighb_connection_point = tmp_in_connection_point
-                    connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point,
-                                   city_cells)
-
+                    new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point,
+                                              city_cells)
+                    all_paths.extend(new_line)
                 direction += 1
-        return
+
+        return all_paths
 
     def _build_inner_cities(node_positions, inner_connection_points, outer_connection_points, node_radius, rail_trans,
                             grid_map):
@@ -891,14 +894,15 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
                 num_agents -= 1
         return agent_start_targets_nodes, num_agents
 
-    def _fix_transitions(city_cells, grid_map):
+    def _fix_transitions(city_cells, inter_city_lines, grid_map):
         """
         Function to fix all transition elements in environment
         """
         # Fix all nodes with illegal transition maps
         rails_to_fix = np.zeros(2 * grid_map.height * grid_map.width * 2, dtype='int')
         rails_to_fix_cnt = 0
-        for cell in city_cells:
+        cells_to_fix = city_cells + inter_city_lines
+        for cell in cells_to_fix:
             check = grid_map.cell_neighbours_valid(cell, True)
             if grid_map.grid[cell] == int('1000010000100001', 2):
                 grid_map.fix_transitions(cell)