diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 44a086b770fb3b276ab87e825f7d10802a86904b..53105ab15dbbbf911be688d05b0cad4fcc4a82ef 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -36,12 +36,12 @@ env = RailEnv(width=50,
                                                    num_trainstations=100,  # Number of possible start/targets on map
                                                    min_node_dist=10,  # Minimal distance of nodes
                                                    node_radius=4,  # Proximity of stations to city center
-                                                   num_neighb=2,  # Number of connections to other cities/intersections
+                                                   num_neighb=3,  # Number of connections to other cities/intersections
                                                    seed=15,  # Random seed
                                                    grid_mode=True,
                                                    nr_parallel_tracks=2,
-                                                   connectin_points_per_side=5,
-                                                   max_nr_connection_directions=2,
+                                                   connectin_points_per_side=2,
+                                                   max_nr_connection_directions=3,
                                                    ),
               schedule_generator=sparse_schedule_generator(),
               number_of_agents=50,
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 33fc408d67e52557aa89b5f8c8526557dd608b2a..fdf6e2167fe1d7e4f698e5fdf5223d92ed26ec60 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -591,22 +591,154 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
 
         # reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode
         nb_nodes = len(node_positions)
-        _num_cities = len(city_positions)
-
-        # Chose node connection
-        # Set up list of available nodes to connect to
-        available_nodes = np.arange(nb_nodes)
 
         # Set up connection points for all cities
-        connection_points = _generate_node_connection_points(node_positions, node_radius, connectin_points_per_side,
-                                                             max_nr_connection_directions)
+        connection_points, connection_info = _generate_node_connection_points(node_positions, node_radius,
+                                                                              connectin_points_per_side,
+                                                                              max_nr_connection_directions)
+
+        # Connect the cities through the connection points
+        _connect_cities(node_positions, connection_points, connection_info, rail_trans, grid_map)
+
+        # Build inner cities
+        train_stations, built_num_trainstation = _build_cities(node_positions, connection_points, rail_trans, grid_map)
+
+        # Adjust the number of agents if you could not build enough trainstations
+        if num_agents > built_num_trainstation:
+            num_agents = built_num_trainstation
+            warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents")
+
+        # Fix all transition elements
+        _fix_transitions(grid_map)
+
+        # Generate start target paris
+        agent_start_targets_nodes = _generate_start_target_pairs(num_agents, nb_nodes, train_stations)
+
+        return grid_map, {'agents_hints': {
+            'num_agents': num_agents,
+            'agent_start_targets_nodes': agent_start_targets_nodes,
+            'train_stations': train_stations
+        }}
+
+    def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes,
+                                               width):
+
+        node_positions = []
+        for node_idx in range(nb_nodes):
+            to_close = True
+            tries = 0
+
+            while to_close:
+                x_tmp = node_radius + np.random.randint(height - 2 * node_radius - 1)
+                y_tmp = node_radius + np.random.randint(width - 2 * node_radius - 1)
+                to_close = False
+
+                # Check distance to cities
+                for node_pos in city_positions:
+                    if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
+                        to_close = True
+
+                # Check distance to intersections
+                for node_pos in intersection_positions:
+                    if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
+                        to_close = True
+
+                if not to_close:
+                    node_positions.append((x_tmp, y_tmp))
+                    if node_idx < num_cities:
+                        city_positions.append((x_tmp, y_tmp))
+                    else:
+                        intersection_positions.append((x_tmp, y_tmp))
+                tries += 1
+                if tries > 100:
+                    warnings.warn(
+                        "Could not only set {} nodes after {} tries, although {} of nodes required to be generated!".format(
+                            len(node_positions),
+                            tries, nb_nodes))
+                    break
+
+        node_positions = city_positions + intersection_positions
+        return node_positions
+
+    def _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, nb_nodes,
+                                           nodes_per_row, x_positions, y_positions):
+
+        for node_idx in range(nb_nodes):
+
+            x_tmp = x_positions[node_idx % nodes_per_row]
+            y_tmp = y_positions[node_idx // nodes_per_row]
+            if node_idx in city_idx:
+                city_positions.append((x_tmp, y_tmp))
+
+            else:
+                intersection_positions.append((x_tmp, y_tmp))
+        node_positions = city_positions + intersection_positions
+        return node_positions
+
+    def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2,
+                                         max_nr_connection_directions=2):
+        connection_points = []
+        connection_info = []
+        for node_position in node_positions:
+
+            connection_sides_idx = np.sort(
+                np.random.choice(np.arange(4), size=max_nr_connection_directions, replace=False))
+
+            # Chose the directions where close cities are situated
+            neighb_dist = []
+            for neighb_node in node_positions:
+                neighb_dist.append(distance_on_rail(node_position, neighb_node))
+            closest_neighb_idx = argsort(neighb_dist)
+
+            # Store the directions to these neighbours
+            connection_sides_idx = []
+            for idx in range(1, max_nr_connection_directions + 1):
+                connection_sides_idx.append(closest_direction(node_position, node_positions[closest_neighb_idx[idx]]))
+
+            # set the number of connection points for each direction
+            connections_per_direction = np.zeros(4, dtype=int)
+
+            for idx in connection_sides_idx:
+                connections_per_direction[idx] = max_nr_connection_points
+            connection_points_coordinates = []
+
+            for direction in range(4):
+                connection_slots = np.arange(connections_per_direction[direction]) - int(
+                        connections_per_direction[direction] / 2)
+                for connection_idx in range(connections_per_direction[direction]):
+                    if direction == 0:
+                        connection_points_coordinates.append(
+                            (node_position[0] - node_size, node_position[1] + connection_slots[connection_idx]))
+                    if direction == 1:
+                        connection_points_coordinates.append(
+                            (node_position[0] + connection_slots[connection_idx], node_position[1] + node_size))
+                    if direction == 2:
+                        connection_points_coordinates.append(
+                            (node_position[0] + node_size, node_position[1] + connection_slots[connection_idx]))
+                    if direction == 3:
+                        connection_points_coordinates.append(
+                            (node_position[0] + connection_slots[connection_idx], node_position[1] - node_size))
+
+            connection_points.append(connection_points_coordinates)
+            connection_info.append(connections_per_direction)
+        return connection_points, connection_info
+
+    def _connect_cities(node_positions, connection_points, connection_info, rail_trans, grid_map):
+        """
+        Function to connect the different cities through their connection points
+        :param node_positions: Positions of city centers
+        :param connection_points: Boarder connection points of cities
+        :param connection_info: Number of connection points per direction NESW
+        :param rail_trans: Transitions
+        :param grid_map: Grid map
+        :return:
+        """
 
         # Start at some node
+        available_nodes = np.arange(len(node_positions))
         current_node = np.random.randint(len(available_nodes))
         node_stack = [current_node]
         open_nodes = np.copy(available_nodes)
-        allowed_connections = num_neighb
-        i = 0
         boarder_connections = set()
         while len(open_nodes) > 0:
             if len(node_stack) > 0:
@@ -623,7 +755,9 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
                 node_dist.append(distance_on_rail(node_positions[current_node], node_positions[av_node]))
             available_nodes = available_nodes[np.argsort(node_dist)]
 
-            # Set number of neighboring nodes
+            # Set number of neighboring
+            allowed_connections = np.count_nonzero(connection_info[current_node])
+
             if len(available_nodes) >= allowed_connections:
                 connected_neighb_idx = available_nodes[1:allowed_connections + 1]
             else:
@@ -649,16 +783,18 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
                         if tmp_dist < min_connection_dist:
                             min_connection_dist = tmp_dist
                             neighb_connection_point = tmp_in_connection_point
-                    i += 1
                     connect_nodes(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point)
                     boarder_connections.add((tmp_out_connection_point, current_node))
                     boarder_connections.add((neighb_connection_point, neighb))
 
             node_stack.pop(0)
 
+    def _build_cities(node_positions, connection_points, rail_trans, grid_map):
         # Place train stations close to the node
         # We currently place them uniformly distributed among all cities
         built_num_trainstation = 0
+        nb_nodes = len(node_positions)
+        height, width = np.shape(grid_map.grid)
         train_stations = [[] for i in range(nb_nodes)]
         if nb_nodes > 1:
 
@@ -705,12 +841,6 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
                     connection = connect_nodes(rail_trans, grid_map,
                                                connection_points[trainstation_node][corner_node_idx],
                                                (station_x, station_y))
-                    if len(connection) != 0:
-                        if (connection_points[trainstation_node][corner_node_idx],
-                            trainstation_node) in boarder_connections:
-                            boarder_connections.remove(
-                                (connection_points[trainstation_node][corner_node_idx], trainstation_node))
-
                 # Check if connection was made
                 if len(connection) == 0:
                     if len(train_stations[trainstation_node]) > 0:
@@ -718,40 +848,16 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
                 else:
 
                     built_num_trainstation += 1
-        # Adjust the number of agents if you could not build enough trainstations
-        if num_agents > built_num_trainstation:
-            num_agents = built_num_trainstation
-            warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents")
-
-        # Connect all disjunct parts of the network
-
-        if len(boarder_connections) > 0:
-            to_be_deleted = []
-            for disjunct_node in boarder_connections:
-                if len(train_stations[disjunct_node[1]]) > 0:
-                    conn = connect_nodes(rail_trans, grid_map,
-                                         disjunct_node[0],
-                                         train_stations[disjunct_node[1]][-1])
-                else:
-                    conn = connect_nodes(rail_trans, grid_map,
-                                         disjunct_node[0],
-                                         node_positions[disjunct_node[1]])
-                if len(conn) > 0:
-                    to_be_deleted.append(disjunct_node)
-                else:
-                    conn = connect_nodes(rail_trans, grid_map,
-                                         disjunct_node[0],
-                                         node_positions[disjunct_node[1]])
-                    if len(conn) > 0:
-                        to_be_deleted.append(disjunct_node)
-
-            for tbd in to_be_deleted:
-                boarder_connections.remove(tbd)
-            print(boarder_connections)
+        return train_stations, built_num_trainstation
 
+    def _fix_transitions(grid_map):
+        """
+        Function to fix all transition elements in environment
+        """
         # Fix all nodes with illegal transition maps
         empty_to_fix = []
         rails_to_fix = []
+        height, width = np.shape(grid_map.grid)
         for r in range(height):
             for c in range(width):
                 rc_pos = (r, c)
@@ -770,6 +876,8 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
         for cell in rails_to_fix:
             grid_map.fix_transitions(cell)
 
+    def _generate_start_target_pairs(num_agents, nb_nodes, train_stations):
+
         # Generate start and target node directory for all agents.
         # Assure that start and target are not in the same node
         agent_start_targets_nodes = []
@@ -805,116 +913,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
                 agent_start_targets_nodes.append((start_node, target_node))
             else:
                 num_agents -= 1
-
-        return grid_map, {'agents_hints': {
-            'num_agents': num_agents,
-            'agent_start_targets_nodes': agent_start_targets_nodes,
-            'train_stations': train_stations
-        }}
-
-    def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes,
-                                               width):
-
-        node_positions = []
-        for node_idx in range(nb_nodes):
-            to_close = True
-            tries = 0
-
-            while to_close:
-                x_tmp = node_radius + np.random.randint(height - 2 * node_radius - 1)
-                y_tmp = node_radius + np.random.randint(width - 2 * node_radius - 1)
-                to_close = False
-
-                # Check distance to cities
-                for node_pos in city_positions:
-                    if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
-                        to_close = True
-
-                # Check distance to intersections
-                for node_pos in intersection_positions:
-                    if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
-                        to_close = True
-
-                if not to_close:
-                    node_positions.append((x_tmp, y_tmp))
-                    if node_idx < num_cities:
-                        city_positions.append((x_tmp, y_tmp))
-                    else:
-                        intersection_positions.append((x_tmp, y_tmp))
-                tries += 1
-                if tries > 100:
-                    warnings.warn(
-                        "Could not only set {} nodes after {} tries, although {} of nodes required to be generated!".format(
-                            len(node_positions),
-                            tries, nb_nodes))
-                    break
-
-        node_positions = city_positions + intersection_positions
-        return node_positions
-
-    def _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, nb_nodes,
-                                           nodes_per_row, x_positions, y_positions):
-
-        for node_idx in range(nb_nodes):
-
-            x_tmp = x_positions[node_idx % nodes_per_row]
-            y_tmp = y_positions[node_idx // nodes_per_row]
-            if node_idx in city_idx:
-                city_positions.append((x_tmp, y_tmp))
-
-            else:
-                intersection_positions.append((x_tmp, y_tmp))
-        node_positions = city_positions + intersection_positions
-        return node_positions
-
-    def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2,
-                                         max_nr_connection_directions=2):
-        connection_points = []
-        for node_position in node_positions:
-
-            connection_sides_idx = np.sort(
-                np.random.choice(np.arange(4), size=max_nr_connection_directions, replace=False))
-
-            # Chose the directions where close cities are situated
-            neighb_dist = []
-            for neighb_node in node_positions:
-                neighb_dist.append(distance_on_rail(node_position, neighb_node))
-            closest_neighb_idx = argsort(neighb_dist)
-            connection_sides_idx = []
-            for idx in range(1, max_nr_connection_directions + 1):
-                connection_sides_idx.append(closest_direction(node_position, node_positions[closest_neighb_idx[idx]]))
-
-            connections_per_direction = np.zeros(4, dtype=int)
-            # set the number of connection points for each direction
-            for idx in connection_sides_idx:
-                connections_per_direction[idx] = max_nr_connection_points
-            connection_points_coordinates = []
-            random_connection_slots = False
-            for direction in range(4):
-                if random_connection_slots:
-                    connection_slots = np.random.choice(np.arange(-node_size, node_size),
-                                                        size=connections_per_direction[direction],
-                                                        replace=False)
-                else:
-                    connection_slots = np.arange(connections_per_direction[direction]) - int(
-                        connections_per_direction[direction] / 2)
-                for connection_idx in range(connections_per_direction[direction]):
-                    if direction == 0:
-                        connection_points_coordinates.append(
-                            (node_position[0] - node_size, node_position[1] + connection_slots[connection_idx]))
-                    if direction == 1:
-                        connection_points_coordinates.append(
-                            (node_position[0] + connection_slots[connection_idx], node_position[1] + node_size))
-                    if direction == 2:
-                        connection_points_coordinates.append(
-                            (node_position[0] + node_size, node_position[1] + connection_slots[connection_idx]))
-                    if direction == 3:
-                        connection_points_coordinates.append(
-                            (node_position[0] + connection_slots[connection_idx], node_position[1] - node_size))
-
-            connection_points.append(connection_points_coordinates)
-        return connection_points
-
+        return agent_start_targets_nodes
     def argsort(seq):
         # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
         return sorted(range(len(seq)), key=seq.__getitem__)