diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index ca9346fc71a8d5485d7b71098dc91558d09bd929..44a086b770fb3b276ab87e825f7d10802a86904b 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -40,7 +40,7 @@ env = RailEnv(width=50,
                                                    seed=15,  # Random seed
                                                    grid_mode=True,
                                                    nr_parallel_tracks=2,
-                                                   connectin_points_per_side=3,
+                                                   connectin_points_per_side=5,
                                                    max_nr_connection_directions=2,
                                                    ),
               schedule_generator=sparse_schedule_generator(),
diff --git a/flatland/core/grid/grid_utils.py b/flatland/core/grid/grid_utils.py
index d39fc8a771cf7f78d9203f2f92632694be26ed91..7c34796cb045165819cc442590565c69d17cad91 100644
--- a/flatland/core/grid/grid_utils.py
+++ b/flatland/core/grid/grid_utils.py
@@ -296,3 +296,25 @@ def coordinate_to_position(depth, coords):
 
 def distance_on_rail(pos1, pos2):
     return np.sqrt(np.power(pos1[0] - pos2[0], 2) + np.power(pos1[1] - pos2[1], 2))
+
+
+def closest_direction(pos1, pos2):
+    """
+    Returns the closest direction orientation of position 2 relative to position 1
+    :param pos1: position we are interested in
+    :param pos2: position we want to know it is facing
+    :return: direction NESW as int N:0 E:1 S:2 W:3
+    """
+    diff_vec = np.array((pos1[0] - pos2[0], pos1[1] - pos2[1]))
+    axis = np.argmax(np.power(diff_vec, 2))
+    direction = np.sign(diff_vec[axis])
+    if axis == 0:
+        if direction > 0:
+            return 2
+        else:
+            return 0
+    else:
+        if direction > 0:
+            return 3
+        else:
+            return 1
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 117862a71593f8483928323574d8977898cfdc9c..33fc408d67e52557aa89b5f8c8526557dd608b2a 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -6,7 +6,7 @@ import msgpack
 import numpy as np
 
 from flatland.core.grid.grid4_utils import get_direction, mirror
-from flatland.core.grid.grid_utils import distance_on_rail
+from flatland.core.grid.grid_utils import distance_on_rail, closest_direction
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes
@@ -871,10 +871,20 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
                                          max_nr_connection_directions=2):
         connection_points = []
         for node_position in node_positions:
+
             connection_sides_idx = np.sort(
                 np.random.choice(np.arange(4), size=max_nr_connection_directions, replace=False))
-            connections_per_direction = np.zeros(4, dtype=int)
 
+            # Chose the directions where close cities are situated
+            neighb_dist = []
+            for neighb_node in node_positions:
+                neighb_dist.append(distance_on_rail(node_position, neighb_node))
+            closest_neighb_idx = argsort(neighb_dist)
+            connection_sides_idx = []
+            for idx in range(1, max_nr_connection_directions + 1):
+                connection_sides_idx.append(closest_direction(node_position, node_positions[closest_neighb_idx[idx]]))
+
+            connections_per_direction = np.zeros(4, dtype=int)
             # set the number of connection points for each direction
             for idx in connection_sides_idx:
                 connections_per_direction[idx] = max_nr_connection_points