diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 5c92fd1cde5024b330bc10050203d74cd0d74af5..6396f4123adf895c381c2a21f6d8dc6e4e823b92 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -1,3 +1,5 @@
+import time
+
 import numpy as np
 
 from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
@@ -30,14 +32,14 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
 
 env = RailEnv(width=50,
               height=50,
-              rail_generator=sparse_rail_generator(num_cities=20,  # Number of cities in map (where train stations are)
+              rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map (where train stations are)
                                                    seed=1,  # Random seed
                                                    grid_mode=False,
                                                    max_inter_city_rails=2,
                                                    max_tracks_in_city=4,
                                                    ),
               schedule_generator=sparse_schedule_generator(),
-              number_of_agents=50,
+              number_of_agents=10,
               stochastic_data=stochastic_data,  # Malfunction data generator
               obs_builder_object=GlobalObsForRailEnv())
 
@@ -111,6 +113,7 @@ for step in range(500):
     # reward and whether their are done
     next_obs, all_rewards, done, _ = env.step(action_dict)
     env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
+    time.sleep(100)
     frame_step += 1
     # Update replay buffer and train agent
     for a in range(env.get_num_agents()):
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index d0e75e27418cfb179ae907f4a3036d6d124618f4..089cea5e9f3694e92d9f84fec68b935592e7d537 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -5,6 +5,8 @@ Generator functions are functions that take width, height and num_resets as argu
 a GridTransitionMap object.
 """
 
+import numpy as np
+
 from flatland.core.grid.grid4_astar import a_star
 from flatland.core.grid.grid4_utils import get_direction, mirror
 from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance, IntVector2DArray
@@ -77,6 +79,39 @@ def connect_basic_operation(
     return path
 
 
+def connect_line(rail_trans, grid_map, start, end, openend=False):
+    # Set start cell
+    current_cell = start
+    path = [current_cell]
+    new_trans = grid_map.grid[current_cell]
+    direction = (np.clip(end[0] - start[0], -1, 1), np.clip(end[1] - start[1], -1, 1))
+    if direction[0] == 0:
+        if direction[1] > 0:
+            direction_int = 1
+        else:
+            direction_int = 3
+    else:
+        if direction[0] > 0:
+            direction_int = 2
+        else:
+            direction_int = 0
+    new_trans = rail_trans.set_transition(new_trans, direction_int, direction_int, 1)
+    new_trans = rail_trans.set_transition(new_trans, mirror(direction_int), mirror(direction_int), 1)
+    grid_map.grid[current_cell] = new_trans
+    if openend:
+        grid_map.grid[current_cell] = 0
+    # Set path
+    while current_cell != end:
+        current_cell = tuple(map(lambda x, y: x + y, current_cell, direction))
+        new_trans = grid_map.grid[current_cell]
+        new_trans = rail_trans.set_transition(new_trans, direction_int, direction_int, 1)
+        new_trans = rail_trans.set_transition(new_trans, mirror(direction_int), mirror(direction_int), 1)
+        grid_map.grid[current_cell] = new_trans
+        if current_cell == end and openend:
+            grid_map.grid[current_cell] = 0
+        path.append(current_cell)
+    return path
+
 def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
                  start: IntVector2D, end: IntVector2D,
                  a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
@@ -106,3 +141,8 @@ def connect_to_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap
                      start: IntVector2D, end: IntVector2D,
                      a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
     return connect_basic_operation(rail_trans, grid_map, start, end, True, False, a_star_distance_function)
+
+
+def connect_straigt_line(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D,
+                         end: IntVector2D, openend=False) -> IntVector2DArray:
+    return connect_line(rail_trans, grid_map, start, end, openend)
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index d6e14459215fbcf2ea8086b3ef82833384e45b4a..6cfbece9a3203772e05ddb08dff51575f4f012d4 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -10,7 +10,7 @@ from flatland.core.grid.grid4_utils import get_direction, mirror
 from flatland.core.grid.grid_utils import distance_on_rail, direction_to_point
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.grid4_generators_utils import connect_rail, connect_cities
+from flatland.envs.grid4_generators_utils import connect_rail, connect_cities, connect_straigt_line
 
 RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
 RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
@@ -552,7 +552,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
         rail_array = grid_map.grid
         rail_array.fill(0)
         np.random.seed(seed + num_resets)
-        node_radius = int(np.ceil((max_tracks_in_city + 2) / 2.0)) + 2
+        node_radius = int(np.ceil((max_tracks_in_city + 2) / 2.0)) + 1
         max_inter_city_rails_allowed = max_inter_city_rails
         if max_inter_city_rails_allowed > max_tracks_in_city:
             max_inter_city_rails_allowed = max_tracks_in_city
@@ -604,7 +604,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
 
         # Fix all transition elements
         grid_fix_time = time.time()
-        _fix_transitions(grid_map)
+        _fix_transitions(city_cells, grid_map)
         print("Grid fix time", time.time() - grid_fix_time)
 
         # Generate start target pairs
@@ -791,14 +791,17 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
             opposite_boarder = (boarder + 2) % 4
             boarder_one = inner_connection_points[current_city][boarder]
             boarder_two = inner_connection_points[current_city][opposite_boarder]
-            connect_cities(rail_trans, grid_map, boarder_one[0], boarder_one[-1])
-            connect_cities(rail_trans, grid_map, boarder_two[0], boarder_two[-1])
 
+            # Connect the ends of the tracks
+            connect_straigt_line(rail_trans, grid_map, boarder_one[0], boarder_one[-1], False)
+            connect_straigt_line(rail_trans, grid_map, boarder_two[0], boarder_two[-1], False)
+
+            # Connect parallel tracks
             for track_id in range(len(inner_connection_points[current_city][boarder])):
                 if track_id % 2 == 0:
                     source = inner_connection_points[current_city][boarder][track_id]
                     target = inner_connection_points[current_city][opposite_boarder][track_id]
-                    current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder)
+                    current_track = connect_straigt_line(rail_trans, grid_map, source, target)
                     if target in all_outer_connection_points and source in \
                         all_outer_connection_points and len(through_path_cells[current_city]) < 1:
                         through_path_cells[current_city].extend(current_track)
@@ -806,7 +809,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
                     source = inner_connection_points[current_city][opposite_boarder][track_id]
                     target = inner_connection_points[current_city][boarder][track_id]
 
-                    current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder)
+                    current_track = connect_straigt_line(rail_trans, grid_map, source, target)
                     if target in all_outer_connection_points and source in \
                         all_outer_connection_points and len(through_path_cells[current_city]) < 1:
                         through_path_cells[current_city].extend(current_track)
@@ -888,23 +891,22 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
                 num_agents -= 1
         return agent_start_targets_nodes, num_agents
 
-    def _fix_transitions(grid_map):
+    def _fix_transitions(city_cells, grid_map):
         """
         Function to fix all transition elements in environment
         """
         # Fix all nodes with illegal transition maps
         empty_to_fix = []
         rails_to_fix = []
-        height, width = np.shape(grid_map.grid)
-        for r in range(height):
-            for c in range(width):
-                rc_pos = (r, c)
-                check = grid_map.cell_neighbours_valid(rc_pos, True)
-                if not check:
-                    if grid_map.grid[rc_pos] == 0:
-                        empty_to_fix.append(rc_pos)
-                    else:
-                        rails_to_fix.append(rc_pos)
+        for cell in city_cells:
+            check = grid_map.cell_neighbours_valid(cell, True)
+            if grid_map.grid[cell] == int('1000010000100001', 2):
+                grid_map.fix_transitions(cell)
+            if not check:
+                if grid_map.grid[cell] == 0:
+                    empty_to_fix.append(cell)
+                else:
+                    rails_to_fix.append(cell)
 
         # Fix empty cells first to avoid cutting the network
         for cell in empty_to_fix: