diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 0bbac92e50b0224f765c77935669328150ab1348..f3bfef2874f8170af2386dc20c9a8990840dd84d 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -36,7 +36,7 @@ env = RailEnv(width=50,
                                                    # Number of cities in map (where train stations are)
                                                    seed=1,  # Random seed
                                                    grid_mode=False,
-                                                   max_rails_between_cities=2,
+                                                   max_rails_between_cities=3,
                                                    max_rails_in_city=6,
                                                    ),
               schedule_generator=sparse_schedule_generator(speed_ration_map),
diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index 2a3db9df346faea5a8734e3cb4a78194133cb319..d65663d1a9ac63169cae10c35cb5e909fe32d338 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -561,11 +561,11 @@ class GridTransitionMap(TransitionMap):
         if number_of_incoming == 3:
             self.set_transitions(rcPos, 0)
             hole = np.argwhere(incoming_connections < 1)[0][0]
-            if direction > 0:
+            if direction >= 0:
                 switch_type_idx = (direction - hole + 3) % 4
-                if switch_type_idx == 2:
-                    transition = simple_switch_west_south
                 if switch_type_idx == 0:
+                    transition = simple_switch_west_south
+                elif switch_type_idx == 2:
                     transition = simple_switch_east_south
                 else:
                     transition = np.random.choice(three_way_transitions, 1)
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index 04903d3ad02ec41e9615c2bb488f24c01e163a1d..37fdcf5d5055e9b17310a06ddd3be13f14b8d1c7 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -9,7 +9,7 @@ import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_astar import a_star
-from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point
+from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point, get_new_position
 from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance, IntVector2DArray
 from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
 from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions
@@ -126,3 +126,33 @@ def connect_straight_line_in_grid_map(grid_map: GridTransitionMap, start: IntVec
         grid_map.grid[cell] = transition
 
     return path
+
+
+def fix_inner_nodes(grid_map: GridTransitionMap, inner_node_pos: IntVector2D, rail_trans: RailEnvTransitions):
+    """
+    Fix inner city nodes
+    :param grid_map:
+    :param start:
+    :param rail_trans:
+    :return:
+    """
+    corner_directions = []
+    for direction in range(4):
+        tmp_pos = get_new_position(inner_node_pos, direction)
+        if grid_map.grid[tmp_pos] > 0:
+            corner_directions.append(direction)
+    if len(corner_directions) == 2:
+        transition = 0
+        transition = rail_trans.set_transition(transition, mirror(corner_directions[0]), corner_directions[1], 1)
+        transition = rail_trans.set_transition(transition, mirror(corner_directions[1]), corner_directions[0], 1)
+        grid_map.grid[inner_node_pos] = transition
+        tmp_pos = get_new_position(inner_node_pos, corner_directions[0])
+        transition = grid_map.grid[tmp_pos]
+        transition = rail_trans.set_transition(transition, corner_directions[0], mirror(corner_directions[0]), 1)
+        grid_map.grid[tmp_pos] = transition
+        tmp_pos = get_new_position(inner_node_pos, corner_directions[1])
+        transition = grid_map.grid[tmp_pos]
+        transition = rail_trans.set_transition(transition, corner_directions[1], mirror(corner_directions[1]),
+                                               1)
+        grid_map.grid[tmp_pos] = transition
+    return
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index d32d1521c9fa98918a1744f6e5874c6ae88bbbcf..324e21b633b272e72c67e0a4b2b88aa024dcdaa4 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -12,7 +12,8 @@ from flatland.core.grid.grid_utils import distance_on_rail, IntVector2DArray, In
     Vec2dOperations
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map
+from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \
+    fix_inner_nodes
 
 RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
 RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
@@ -698,22 +699,35 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
             start_idx = int((nr_of_connection_points - number_of_out_rails) / 2)
             for direction in range(4):
                 connection_slots = np.arange(nr_of_connection_points) - start_idx
+                inner_point_offset = np.abs(connection_slots) + np.clip(connection_slots, 0, 1)
                 for connection_idx in range(connections_per_direction[direction]):
                     if direction == 0:
                         tmp_coordinates = (
+                            city_position[0] - city_radius + inner_point_offset[connection_idx],
+                            city_position[1] + connection_slots[connection_idx])
+                        out_tmp_coordinates = (
                             city_position[0] - city_radius, city_position[1] + connection_slots[connection_idx])
                     if direction == 1:
                         tmp_coordinates = (
+                            city_position[0] + connection_slots[connection_idx],
+                            city_position[1] + city_radius - inner_point_offset[connection_idx])
+                        out_tmp_coordinates = (
                             city_position[0] + connection_slots[connection_idx], city_position[1] + city_radius)
                     if direction == 2:
                         tmp_coordinates = (
+                            city_position[0] + city_radius - inner_point_offset[connection_idx],
+                            city_position[1] + connection_slots[connection_idx])
+                        out_tmp_coordinates = (
                             city_position[0] + city_radius, city_position[1] + connection_slots[connection_idx])
                     if direction == 3:
                         tmp_coordinates = (
+                            city_position[0] + connection_slots[connection_idx],
+                            city_position[1] - city_radius + inner_point_offset[connection_idx])
+                        out_tmp_coordinates = (
                             city_position[0] + connection_slots[connection_idx], city_position[1] - city_radius)
                     connection_points_coordinates_inner[direction].append(tmp_coordinates)
                     if connection_idx in range(start_idx, start_idx + number_of_out_rails):
-                        connection_points_coordinates_outer[direction].append(tmp_coordinates)
+                        connection_points_coordinates_outer[direction].append(out_tmp_coordinates)
 
             inner_connection_points.append(connection_points_coordinates_inner)
             outer_connection_points.append(connection_points_coordinates_outer)
@@ -789,11 +803,9 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         :param grid_map:
         :return: Returns the cells of the through path which cannot be occupied by trainstations
         """
-        through_path_cells: List[IntVector2DArray] = [[] for i in range(len(city_positions))]
         free_rails: List[List[List[IntVector2D]]] = [[] for i in range(len(city_positions))]
         for current_city in range(len(city_positions)):
-            all_outer_connection_points = [item for sublist in outer_connection_points[current_city] for item in
-                                           sublist]
+
             # This part only works if we have keep same number of connection points for both directions
             # Also only works with two connection direction at each city
             for i in range(4):
@@ -802,18 +814,29 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
                     break
 
             opposite_boarder = (boarder + 2) % 4
-            boarder_one = inner_connection_points[current_city][boarder]
-            boarder_two = inner_connection_points[current_city][opposite_boarder]
-
-            # Connect the ends of the tracks
-            connect_straight_line_in_grid_map(grid_map, boarder_one[0], boarder_one[-1], rail_trans)
-            connect_straight_line_in_grid_map(grid_map, boarder_two[0], boarder_two[-1], rail_trans)
-
+            nr_of_connection_points = len(inner_connection_points[current_city][boarder])
+            number_of_out_rails = len(outer_connection_points[current_city][boarder])
+            start_idx = int((nr_of_connection_points - number_of_out_rails) / 2)
             # Connect parallel tracks
-            for track_id in range(len(inner_connection_points[current_city][boarder])):
+            for track_id in range(nr_of_connection_points):
                 source = inner_connection_points[current_city][boarder][track_id]
                 target = inner_connection_points[current_city][opposite_boarder][track_id]
                 current_track = connect_straight_line_in_grid_map(grid_map, source, target, rail_trans)
+
+            for track_id in range(nr_of_connection_points):
+                source = inner_connection_points[current_city][boarder][track_id]
+                target = inner_connection_points[current_city][opposite_boarder][track_id]
+                fix_inner_nodes(
+                    grid_map, source, rail_trans)
+                fix_inner_nodes(
+                    grid_map, target, rail_trans)
+                if start_idx <= track_id < start_idx + number_of_out_rails:
+                    source_outer = outer_connection_points[current_city][boarder][track_id - start_idx]
+                    target_outer = outer_connection_points[current_city][opposite_boarder][track_id - start_idx]
+                    connect_straight_line_in_grid_map(grid_map, source, source_outer, rail_trans)
+                    connect_straight_line_in_grid_map(grid_map, target, target_outer, rail_trans)
+
+
                 free_rails[current_city].append(current_track)
         return free_rails
 
@@ -870,18 +893,20 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         cells_to_fix = city_cells + inter_city_lines
         for cell in cells_to_fix:
             cell_valid = grid_map.cell_neighbours_valid(cell, True)
-            if grid_map.grid[cell] == int('1000010000100001', 2):
-                grid_map.fix_transitions(cell)
+            # cell_valid = grid_map.transitions.is_valid(cell)
+            # if grid_map.grid[cell] == int('1000010000100001', 2):
+            #    grid_map.fix_transitions(cell)
+            # if bin(grid_map.grid[cell]).count("1") == 4:
+            #    cell_valid = False
+            #    print("fixing cell", cell, vector_field[cell])
             if not cell_valid:
                 rails_to_fix[3 * rails_to_fix_cnt] = cell[0]
                 rails_to_fix[3 * rails_to_fix_cnt + 1] = cell[1]
-                rails_to_fix[3 * rails_to_fix_cnt + 2] = vector_field[(cell[0], cell[1])]
+                rails_to_fix[3 * rails_to_fix_cnt + 2] = vector_field[cell]
                 rails_to_fix_cnt += 1
-
         # Fix all other cells
         for cell in range(rails_to_fix_cnt):
-            grid_map.fix_transitions((rails_to_fix[3 * cell], rails_to_fix[3 * cell + 1]),
-                                     rails_to_fix[3 * rails_to_fix_cnt + 2])
+            grid_map.fix_transitions((rails_to_fix[3 * cell], rails_to_fix[3 * cell + 1]), rails_to_fix[3 * cell + 2])
 
     def _closest_neighbour_in_grid4_directions(current_city_idx: int, city_positions: IntVector2DArray) -> List[int]:
         """
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index 699f1e91cce54d6486b5736a4884a4a5d804f1f6..08b8bdc9789df1a178b997860521630c468373d5 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -570,8 +570,8 @@ def test_sparse_rail_generator():
     for a in range(env.get_num_agents()):
         s0 = Vec2d.get_manhattan_distance(env.agents[a].initial_position, (0, 0))
         s1 = Vec2d.get_chebyshev_distance(env.agents[a].initial_position, (0, 0))
-    assert s0 == 61, "actual={}".format(s0)
-    assert s1 == 42, "actual={}".format(s1)
+    assert s0 == 39, "actual={}".format(s0)
+    assert s1 == 27, "actual={}".format(s1)
 
 
 def test_sparse_rail_generator_deterministic():
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index ace603422006d7f6229e854e24b6f244b7fe183c..7cf6812654a8d26fd19dd174f07bc0901785463b 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -158,7 +158,7 @@ def test_malfunction_process_statistically():
         env.step(action_dict)
 
     # check that generation of malfunctions works as expected
-    assert nb_malfunction == 128, "nb_malfunction={}".format(nb_malfunction)
+    assert nb_malfunction == 152, "nb_malfunction={}".format(nb_malfunction)
 
 
 def test_initial_malfunction():