diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 53105ab15dbbbf911be688d05b0cad4fcc4a82ef..fc358a2a7da1367c05fe2a52f4ce35a6c551c588 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -32,15 +32,15 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
 
 env = RailEnv(width=50,
               height=50,
-              rail_generator=sparse_rail_generator(num_cities=9,  # Number of cities in map (where train stations are)
+              rail_generator=sparse_rail_generator(num_cities=3,  # Number of cities in map (where train stations are)
                                                    num_trainstations=100,  # Number of possible start/targets on map
                                                    min_node_dist=10,  # Minimal distance of nodes
                                                    node_radius=4,  # Proximity of stations to city center
                                                    num_neighb=3,  # Number of connections to other cities/intersections
                                                    seed=15,  # Random seed
-                                                   grid_mode=True,
+                                                   grid_mode=False,
                                                    nr_parallel_tracks=2,
-                                                   connectin_points_per_side=2,
+                                                   connectin_points_per_side=100,
                                                    max_nr_connection_directions=3,
                                                    ),
               schedule_generator=sparse_schedule_generator(),
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 853a6376126f8becde8a1b0821bc3ba778910edf..b7ca0f4d94d2bb1ce7e9bbfd8930662ec0893aa1 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -683,11 +683,11 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
                                          max_nr_connection_directions=2):
         connection_points = []
         connection_info = []
-        for node_position in node_positions:
-
-            connection_sides_idx = np.sort(
-                np.random.choice(np.arange(4), size=max_nr_connection_directions, replace=False))
+        max_nr_connection_directions = np.clip(max_nr_connection_directions, 0, 4)
+        if max_nr_connection_points > 2 * node_size + 1:
+            max_nr_connection_points = 2 * node_size + 1
 
+        for node_position in node_positions:
             # Chose the directions where close cities are situated
             neighb_dist = []
             for neighb_node in node_positions:
@@ -696,7 +696,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
 
             # Store the directions to these neighbours
             connection_sides_idx = []
-            for idx in range(1, max_nr_connection_directions + 1):
+            for idx in range(1, min(len(neighb_dist) - 1, max_nr_connection_directions) + 1):
                 connection_sides_idx.append(closest_direction(node_position, node_positions[closest_neighb_idx[idx]]))
 
             # set the number of connection points for each direction
@@ -918,6 +918,35 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
             else:
                 num_agents -= 1
         return agent_start_targets_nodes
+
+    def _closest_neigh_in_direction(current_node, direction, node_positions):
+        # Sort available neighbors according to their distance.
+        available_nodes = np.arange(node_positions)
+        node_dist = []
+        for av_node in available_nodes:
+            node_dist.append(distance_on_rail(node_positions[current_node], node_positions[av_node]))
+        sorted_neighbours = available_nodes[np.argsort(node_dist)]
+
+        for neighb in sorted_neighbours[1:]:
+            distance_0 = np.abs(node_positions[current_node][0] - node_positions[neighb][0])
+            distance_1 = np.abs(node_positions[current_node][1] - node_positions[neighb][1])
+            if direction == 0:
+                if node_positions[neighb][0] < node_positions[current_node][0] and distance_1 <= distance_0:
+                    return neighb
+
+            if direction == 1:
+                if node_positions[neighb][1] > node_positions[current_node][1] and distance_0 <= distance_1:
+                    return neighb
+
+            if direction == 2:
+                if node_positions[neighb][0] > node_positions[current_node][0] and distance_1 <= distance_0:
+                    return neighb
+
+            if direction == 3:
+                if node_positions[neighb][0] < node_positions[current_node][0] and distance_0 <= distance_1:
+                    return neighb
+        return None
+
     def argsort(seq):
         # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
         return sorted(range(len(seq)), key=seq.__getitem__)