diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py
index 2057cbc1dd986d4e23df9f3d4cfeeedd91a35b6a..ee5fdc38f69356da4cabbf33fe44d28143ad46be 100644
--- a/examples/Simple_Realistic_Railway_Generator.py
+++ b/examples/Simple_Realistic_Railway_Generator.py
@@ -1,11 +1,13 @@
+import os
 import time
 import warnings
 
 import numpy as np
 
+from flatland.core.grid.grid4_utils import mirror
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.grid4_generators_utils import connect_from_nodes, connect_nodes
+from flatland.envs.grid4_generators_utils import connect_from_nodes, connect_nodes, connect_rail
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct
@@ -13,7 +15,7 @@ from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
 
 
-class PositionOps:
+class Vec2dOperations:
     def subtract_pos(nodeA, nodeB):
         """
         vector operation : nodeA - nodeB
@@ -69,10 +71,10 @@ class PositionOps:
             -------
         tuple with coordinate (x,y) or 2d vector
         """
-        n = PositionOps.get_norm_pos(node)
+        n = Vec2dOperations.get_norm_pos(node)
         if n > 0.0:
             n = 1 / n
-        return PositionOps.scale_pos(node, n)
+        return Vec2dOperations.scale_pos(node, n)
 
     def scale_pos(node, scalar):
         """
@@ -139,6 +141,58 @@ class PositionOps:
         return (x1, y1)
 
 
+def min_max_cut(min_v, max_v, v):
+    return max(min_v, min(max_v, v))
+
+
+def add_rail(width, height, grid_map, pt_from, pt_via, pt_to, bAddRemove=True):
+    gRCTrans = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])  # NESW in RC
+
+    lrcStroke = [[min_max_cut(0, height - 1, pt_from[0]),
+                  min_max_cut(0, width - 1, pt_from[1])],
+                 [min_max_cut(0, height - 1, pt_via[0]),
+                  min_max_cut(0, width - 1, pt_via[1])],
+                 [min_max_cut(0, height - 1, pt_to[0]),
+                  min_max_cut(0, width - 1, pt_to[1])]]
+
+    rc3Cells = np.array(lrcStroke[:3])  # the 3 cells
+    rcMiddle = rc3Cells[1]  # the middle cell which we will update
+    bDeadend = np.all(lrcStroke[0] == lrcStroke[2])  # deadend means cell 0 == cell 2
+
+    # get the 2 row, col deltas between the 3 cells, eg [[-1,0],[0,1]] = North, East
+    rc2Trans = np.diff(rc3Cells, axis=0)
+
+    # get the direction index for the 2 transitions
+    liTrans = []
+    for rcTrans in rc2Trans:
+        # gRCTrans - rcTrans gives an array of vector differences between our rcTrans
+        # and the 4 directions stored in gRCTrans.
+        # Where the vector difference is zero, we have a match...
+        # np.all detects where the whole row,col vector is zero.
+        # argwhere gives the index of the zero vector, ie the direction index
+        iTrans = np.argwhere(np.all(gRCTrans - rcTrans == 0, axis=1))
+        if len(iTrans) > 0:
+            iTrans = iTrans[0][0]
+            liTrans.append(iTrans)
+
+    # check that we have two transitions
+    if len(liTrans) == 2:
+        # Set the transition
+        # Set the transition
+        # If this transition spans 3 cells, it is not a deadend, so remove any deadends.
+        # The user will need to resolve any conflicts.
+        grid_map.set_transition((*rcMiddle, liTrans[0]),
+                                liTrans[1],
+                                bAddRemove,
+                                remove_deadends=not bDeadend)
+
+        # Also set the reverse transition
+        # use the reversed outbound transition for inbound
+        # and the reversed inbound transition for outbound
+        grid_map.set_transition((*rcMiddle, mirror(liTrans[1])),
+                                mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend)
+
+
 def realistic_rail_generator(num_cities=5,
                              city_size=10,
                              allowed_rotation_angles=[0, 90],
@@ -190,22 +244,22 @@ def realistic_rail_generator(num_cities=5,
         for i in range(len(generate_city_locations)):
             # station main orientation  (horizontal or vertical
             rot_angle = np.random.choice(allowed_rotation_angles)
-            add_pos_val = PositionOps.scale_pos(PositionOps.rotate_pos((1, 0), rot_angle),
-                                                (max(1, (intern_city_size - 3) / 2)))
-            generate_city_locations[i][0] = PositionOps.add_pos(generate_city_locations[i][1], add_pos_val)
-            add_pos_val = PositionOps.scale_pos(PositionOps.rotate_pos((1, 0), 180 + rot_angle),
-                                                (max(1, (intern_city_size - 3) / 2)))
-            generate_city_locations[i][1] = PositionOps.add_pos(generate_city_locations[i][1], add_pos_val)
+            add_pos_val = Vec2dOperations.scale_pos(Vec2dOperations.rotate_pos((1, 0), rot_angle),
+                                                    (max(1, (intern_city_size - 3) / 2)))
+            generate_city_locations[i][0] = Vec2dOperations.add_pos(generate_city_locations[i][1], add_pos_val)
+            add_pos_val = Vec2dOperations.scale_pos(Vec2dOperations.rotate_pos((1, 0), 180 + rot_angle),
+                                                    (max(1, (intern_city_size - 3) / 2)))
+            generate_city_locations[i][1] = Vec2dOperations.add_pos(generate_city_locations[i][1], add_pos_val)
         return generate_city_locations
 
     def create_stations_from_city_locations(rail_trans, rail_array, generate_city_locations,
-                                            intern_max_number_of_station_tracks,
-                                            intern_nbr_of_switches_per_station_track):
+                                            intern_max_number_of_station_tracks):
         nodes_added = []
         start_nodes_added = [[] for i in range(len(generate_city_locations))]
         end_nodes_added = [[] for i in range(len(generate_city_locations))]
         station_slots = [[] for i in range(len(generate_city_locations))]
-        switch_slots = [[] for i in range(len(generate_city_locations))]
+        station_tracks = [[[] for j in range(intern_max_number_of_station_tracks)] for i in range(len(
+            generate_city_locations))]
 
         station_slots_cnt = 0
 
@@ -216,13 +270,13 @@ def realistic_rail_generator(num_cities=5,
                 org_start_node = generate_city_locations[city_loop][0]
                 org_end_node = generate_city_locations[city_loop][1]
 
-                ortho_trans = PositionOps.make_orthogonal_pos(
-                    PositionOps.normalize_pos(PositionOps.subtract_pos(org_start_node, org_end_node)))
+                ortho_trans = Vec2dOperations.make_orthogonal_pos(
+                    Vec2dOperations.normalize_pos(Vec2dOperations.subtract_pos(org_start_node, org_end_node)))
                 s = (ct - number_of_connecting_tracks / 2.0)
-                start_node = PositionOps.ceil_pos(
-                    PositionOps.add_pos(org_start_node, PositionOps.scale_pos(ortho_trans, s)))
-                end_node = PositionOps.ceil_pos(
-                    PositionOps.add_pos(org_end_node, PositionOps.scale_pos(ortho_trans, s)))
+                start_node = Vec2dOperations.ceil_pos(
+                    Vec2dOperations.add_pos(org_start_node, Vec2dOperations.scale_pos(ortho_trans, s)))
+                end_node = Vec2dOperations.ceil_pos(
+                    Vec2dOperations.add_pos(org_end_node, Vec2dOperations.scale_pos(ortho_trans, s)))
 
                 connection = connect_from_nodes(rail_trans, rail_array, start_node, end_node)
                 if len(connection) > 0:
@@ -236,59 +290,78 @@ def realistic_rail_generator(num_cities=5,
                     station_slots[city_loop].append(connection[int(np.floor(len(connection) / 2))])
                     station_slots_cnt += 1
 
-                    # generate random switch positions (switch slots)
-                    if len(connection) - 3 - nbr_of_switches_per_station_track - 1> 0:
-                        idxs = np.sort(np.random.choice(np.arange(len(connection) - 3),
-                                                        nbr_of_switches_per_station_track + 1,False))
-                        idx_loop_cnt = 0
-                        for idx in idxs:
-                            pt = connection[idx + 1]
-                            if idx_loop_cnt % 2 == 1:
-                                s = (ct - number_of_connecting_tracks / 2.0)
-                                pt = PositionOps.ceil_pos(
-                                    PositionOps.add_pos(pt, PositionOps.scale_pos(ortho_trans, s)))
-                            switch_slots[city_loop].append(pt)
-                            idx_loop_cnt += 1
-
-        # generate switch based on switch slot list and connect them
-        for city_loop in range(len(switch_slots)):
-            data = switch_slots[city_loop]
-            data_idx = np.random.choice(np.arange(len(data)), len(data), False)
-            for i in range(len(data) - 1):
-                start_node = data[data_idx[i]]
-                end_node = data[data_idx[i + 1]]
-                connection = connect_from_nodes(rail_trans, rail_array, start_node, end_node)
-                if len(connection) > 0:
-                    station_slots[city_loop].append(connection[int(np.floor(len(connection) / 2))])
-                    nodes_added.append(start_node)
-                    nodes_added.append(end_node)
+                    station_tracks[city_loop][ct] = connection
 
         if print_out_info:
             print("max nbr of station slots with given configuration is:", station_slots_cnt)
 
-        return nodes_added, station_slots, start_nodes_added, end_nodes_added
+        return nodes_added, station_slots, start_nodes_added, end_nodes_added, station_tracks
 
-    def connect_stations(rail_trans, rail_array, start_nodes_added, end_nodes_added, nodes_added,
-                         inter_max_number_of_connecting_tracks, do_random_connect_stations):
+    def create_switches_at_stations(width, height, grid_map, station_tracks, nodes_added,
+                                    intern_nbr_of_switches_per_station_track):
+        # generate switch based on switch slot list and connect them
+        for city_loop in range(len(station_tracks)):
+            datas = station_tracks[city_loop]
+            for data_loop in range(len(datas) - 1):
+                data = datas[data_loop]
+                data1 = datas[data_loop + 1]
+                if len(data) > 2 and len(data1) > 2:
+                    for i in np.random.choice(min(len(data1), len(data)) - 2, 2):
+                        add_rail(width, height, grid_map, data[i + 1], data1[i + 1], data1[i + 2], True)
+                        nodes_added.append(data[i + 1])
+                        nodes_added.append(data1[i + 1])
+                        nodes_added.append(data1[i + 2])
+
+        return nodes_added
+
+    def connect_stations(rail_trans, rail_array, org_s_nodes, org_e_nodes, nodes_added,
+                         inter_max_number_of_connecting_tracks):
+
+        s_nodes = org_s_nodes.copy()
+        e_nodes = org_e_nodes.copy()
+
+        for k in range(inter_max_number_of_connecting_tracks):
+            for city_loop in range(len(s_nodes)):
+                sns = s_nodes[city_loop]
+                cl = 0
+                min_distance = np.inf
+                end_node = None
+                start_node = None
+                for city_loop_find_shortest in range(len(e_nodes)):
+                    if city_loop_find_shortest == city_loop:
+                        continue
+                    ens = e_nodes[city_loop_find_shortest]
+                    for en in ens:
+                        for sn in sns:
+                            d = Vec2dOperations.get_norm_pos(Vec2dOperations.subtract_pos(en, sn))
+                            if d < min_distance:
+                                min_distance = d
+                                end_node = en
+                                start_node = sn
+                                cl = city_loop_find_shortest
+
+                if end_node is not None:
+                    tmp_trans_sn = rail_array[start_node]
+                    tmp_trans_en = rail_array[end_node]
+                    rail_array[start_node] = 0
+                    rail_array[end_node] = 0
+                    connection = connect_rail(rail_trans, rail_array, start_node, end_node)
+                    if len(connection) > 0:
+                        s_nodes[city_loop].remove(start_node)
+                        e_nodes[cl].remove(end_node)
+                        nodes_added.append(start_node)
+                        nodes_added.append(end_node)
+                    else:
+                        rail_array[start_node] = tmp_trans_sn
+                        rail_array[end_node] = tmp_trans_en
+
+    def connect_random_stations(rail_trans, rail_array, start_nodes_added, end_nodes_added, nodes_added,
+                                inter_max_number_of_connecting_tracks):
         x = np.arange(len(start_nodes_added))
-        if do_random_connect_stations:
-            random_city_idx = np.random.choice(x, len(x), False)
-        else:
-            a = [[] for i in x]
-            b = []
-            for yLoop in x:
-                for xLoop in x:
-                    v = PositionOps.get_norm_pos(
-                        PositionOps.subtract_pos(start_nodes_added[xLoop][0], end_nodes_added[yLoop][0]))
-                    if v > 0:
-                        v = np.inf
-                    a[yLoop].append(v)
-            for i in range(len(a)):
-                b.append(np.argmin(a[i]))
-            random_city_idx = np.argsort(b)
+        random_city_idx = np.random.choice(x, len(x), False)
 
         # cyclic connection
-        random_city_idx = np.append(random_city_idx,random_city_idx[0])
+        random_city_idx = np.append(random_city_idx, random_city_idx[0])
 
         for city_loop in range(len(random_city_idx) - 1):
             idx_a = random_city_idx[city_loop + 1]
@@ -377,16 +450,23 @@ def realistic_rail_generator(num_cities=5,
 
         # ----------------------------------------------------------------------------------
         # generate city topology
-        nodes_added, train_stations, s_nodes, e_nodes = \
+        nodes_added, train_stations, s_nodes, e_nodes, station_tracks = \
             create_stations_from_city_locations(rail_trans, rail_array,
                                                 generate_city_locations,
-                                                intern_max_number_of_station_tracks,
-                                                intern_nbr_of_switches_per_station_track)
+                                                intern_max_number_of_station_tracks)
+        # build switches
+        create_switches_at_stations(width, height, grid_map, station_tracks, nodes_added,
+                                    intern_nbr_of_switches_per_station_track)
+
         # ----------------------------------------------------------------------------------
         # connect stations
-        connect_stations(rail_trans, rail_array, s_nodes, e_nodes, nodes_added,
-                         inter_max_number_of_connecting_tracks,
-                         do_random_connect_stations)
+        if True:
+            if do_random_connect_stations:
+                connect_random_stations(rail_trans, rail_array, s_nodes, e_nodes, nodes_added,
+                                        inter_max_number_of_connecting_tracks)
+            else:
+                connect_stations(rail_trans, rail_array, s_nodes, e_nodes, nodes_added,
+                                 inter_max_number_of_connecting_tracks)
 
         # ----------------------------------------------------------------------------------
         # fix all transition at starting / ending points (mostly add a dead end, if missing)
@@ -441,14 +521,14 @@ def realistic_rail_generator(num_cities=5,
 for itrials in range(100):
     print(itrials, "generate new city")
     np.random.seed(int(time.time()))
-    env = RailEnv(width=20+np.random.choice(100),
-                  height=20+np.random.choice(100),
-                  rail_generator=realistic_rail_generator(num_cities=2+np.random.choice(10),
-                                                          city_size=4+np.random.choice(20),
-                                                          allowed_rotation_angles=[-90,-30,0,30,90],
+    env = RailEnv(width=100,  # 20+np.random.choice(100),
+                  height=100,  # 20+np.random.choice(100),
+                  rail_generator=realistic_rail_generator(num_cities=100,
+                                                          city_size=20,
+                                                          allowed_rotation_angles=[-90, 0, 90],
                                                           max_number_of_station_tracks=4,
                                                           nbr_of_switches_per_station_track=2,
-                                                          max_number_of_connecting_tracks=3,
+                                                          max_number_of_connecting_tracks=4,
                                                           do_random_connect_stations=False,
                                                           # Number of cities in map
                                                           seed=int(time.time())  # Random seed
@@ -463,4 +543,11 @@ for itrials in range(100):
     while cnt < 10:
         env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
         cnt += 1
+
+    env_renderer.gl.save_image(
+        os.path.join(
+            "./../render_output/",
+            "flatland_frame_{:04d}_{:04d}.png".format(itrials, 0)
+        ))
+
     env_renderer.close_window()