diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 8880d27b852f7d2d54d93ffd23e42156bced0b74..ba12736f6d3ad241f6b323bf1ef7204b036e3668 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -744,25 +744,16 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
         all_paths = []
 
         for current_node in np.arange(len(node_positions)):
-            direction = 0
-            connected_to_city = []
             neighbours = _closest_neigh_in_direction(current_node, node_positions)
-            for nbr_connection_points in connection_info[current_node]:
-                if nbr_connection_points > 0:
-                    neighb_idx = neighbours[direction]
-                else:
-                    direction += 1
-                    continue
-
-                # If no closest neighbour was found look at the neighbouring connections
-                tmp_direction = (direction - 1) % 4
-                while neighb_idx is None:
-                    neighb_idx = neighbours[tmp_direction]
-                    tmp_direction = (direction + 1) % 4
-
-                connected_to_city.append(neighb_idx)
-                for tmp_out_connection_point in connection_points[current_node][direction]:
-                    # Find closest connection point
+            for out_direction in range(4):
+                for tmp_out_connection_point in connection_points[current_node][out_direction]:
+                    # This only needs to be checked when entering this loop
+                    neighb_idx = neighbours[out_direction]
+                    if neighb_idx is None:
+                        tmp_direction = (out_direction - 1) % 4
+                    while neighb_idx is None:
+                        neighb_idx = neighbours[tmp_direction]
+                        tmp_direction = (out_direction + 1) % 4
                     min_connection_dist = np.inf
                     for dir in range(4):
                         current_points = connection_points[neighb_idx][dir]
@@ -776,13 +767,11 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
                     new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point,
                                               neighb_connection_point,
                                               city_cells)
-                    G.add_edge(current_node, neighb_idx, direction=direction, length=len(new_line))
+                    G.add_edge(current_node, neighb_idx, direction=out_direction, length=len(new_line))
                     G.add_edge(neighb_idx, current_node, direction=neighbour_direction, length=len(new_line))
 
                     all_paths.extend(new_line)
 
-                direction += 1
-
         return all_paths
 
     def _build_inner_cities(node_positions, inner_connection_points, outer_connection_points, node_radius, rail_trans,