diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 1520e598de6a351f87367ccba2fc853a86ab64e7..2dc6efa467dcd3a8d1856a968d43b5b4f4dc4b09 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -34,7 +34,7 @@ env = RailEnv(width=50,
               height=50,
               rail_generator=sparse_rail_generator(num_cities=9,  # Number of cities in map (where train stations are)
                                                    num_intersections=0,  # Number of intersections (no start / target)
-                                                   num_trainstations=50,  # Number of possible start/targets on map
+                                                   num_trainstations=10,  # Number of possible start/targets on map
                                                    min_node_dist=3,  # Minimal distance of nodes
                                                    node_radius=4,  # Proximity of stations to city center
                                                    num_neighb=4,  # Number of connections to other cities/intersections
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index e4a7e292b3c4358149e04ddfc1ed2733c8b68b3c..937b8c88feb086627a0fa14d4c73d98e7bac8d9f 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -667,8 +667,8 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
                         if distance_on_rail(tmp_out_connection_point, neighb_connection_point) < center_distance:
                             i += 1
                             connect_nodes(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point)
-                            boarder_connections.add(tmp_out_connection_point)
-                            boarder_connections.add(neighb_connection_point)
+                            boarder_connections.add((tmp_out_connection_point, current_node))
+                            boarder_connections.add((neighb_connection_point, neighb))
 
             node_stack.pop(0)
 
@@ -722,10 +722,11 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
                                                connection_points[trainstation_node][corner_node_idx],
                                                (station_x, station_y))
                     if len(connection) != 0:
-                        if connection_points[trainstation_node][corner_node_idx] in boarder_connections:
-                            boarder_connections.remove(connection_points[trainstation_node][corner_node_idx])
+                        if (connection_points[trainstation_node][corner_node_idx],
+                            trainstation_node) in boarder_connections:
+                            boarder_connections.remove(
+                                (connection_points[trainstation_node][corner_node_idx], trainstation_node))
 
-                grid_map.fix_transitions((station_x, station_y))
                 # Check if connection was made
                 if len(connection) == 0:
                     if len(train_stations[trainstation_node]) > 0:
@@ -737,8 +738,30 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
             num_agents = built_num_trainstation
             warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents")
 
+        # Connect all disjunct parts of the network
+
+        if len(boarder_connections) > 0:
+            to_be_deleted = []
+            for disjunct_node in boarder_connections:
+                print(disjunct_node)
+                conn = connect_nodes(rail_trans, grid_map,
+                                     disjunct_node[0],
+                                     train_stations[disjunct_node[1]][0])
+                if len(conn) > 0:
+                    to_be_deleted.append(disjunct_node)
+
+        for tbd in to_be_deleted:
+            boarder_connections.remove(tbd)
+        print(boarder_connections)
         # Fix all nodes with illegal transition maps
+        flat_trainstation_list = [item for sublist in train_stations for item in sublist]
+        for cell_to_fix in flat_trainstation_list:
+            grid_map.fix_transitions(cell_to_fix)
+
+        grid_map.fix_transitions((station_x, station_y))
+
         flat_list = [item for sublist in connection_points for item in sublist]
+
         for cell_to_fix in flat_list:
             grid_map.fix_transitions(cell_to_fix)