diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 929162114f5de2aeb10bb846c1ca179cca14f822..6dbcd07f690e6e138cfee3354a927b887a915ed9 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -543,7 +543,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
     return generator
 
 
-def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=99999, seed=0):
+def realistic_rail_generator(nr_start_goal=1,  seed=0):
     """
     Parameters
     -------
@@ -572,18 +572,19 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis
                        int('1100000000100010', 2)]  # Case 2b (10) - simple switch mirrored
 
     """
-    def min_max_cut(min_v,max_v,v):
-        return max(min_v,min(max_v,v))
 
-    def add_rail(width,height,grid_map, pt_from, pt_via, pt_to, bAddRemove=True):
+    def min_max_cut(min_v, max_v, v):
+        return max(min_v, min(max_v, v))
+
+    def add_rail(width, height, grid_map, pt_from, pt_via, pt_to, bAddRemove=True):
         gRCTrans = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])  # NESW in RC
 
-        lrcStroke = [[min_max_cut(0,height-1,pt_from[0]),
-                      min_max_cut(0,width-1,pt_from[1])],
-                     [min_max_cut(0,height-1,pt_via[0]),
-                      min_max_cut(0,width-1,pt_via[1])],
-                     [min_max_cut(0,height-1,pt_to[0]),
-                      min_max_cut(0,width-1,pt_to[1])]]
+        lrcStroke = [[min_max_cut(0, height - 1, pt_from[0]),
+                      min_max_cut(0, width - 1, pt_from[1])],
+                     [min_max_cut(0, height - 1, pt_via[0]),
+                      min_max_cut(0, width - 1, pt_via[1])],
+                     [min_max_cut(0, height - 1, pt_to[0]),
+                      min_max_cut(0, width - 1, pt_to[1])]]
 
         rc3Cells = np.array(lrcStroke[:3])  # the 3 cells
         rcMiddle = rc3Cells[1]  # the middle cell which we will update
@@ -622,27 +623,27 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis
             grid_map.set_transition((*rcMiddle, mirror(liTrans[1])),
                                     mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend)
 
-    def make_switch_w_e(width,height,grid_map, center):
+    def make_switch_w_e(width, height, grid_map, center):
         # e -> w
-        start = (center[0]+1, center[1]-1)
+        start = (center[0] + 1, center[1] - 1)
         via = (center[0], center[1] - 1)
         goal = (center[0], center[1])
-        add_rail(width,height,grid_map, start, via, goal)
-        start = (center[0], center[1]-1)
-        via = (center[0]+1, center[1]-1)
-        goal = (center[0]+1, center[1]-2)
-        add_rail(width,height,grid_map, start, via, goal)
+        add_rail(width, height, grid_map, start, via, goal)
+        start = (center[0], center[1] - 1)
+        via = (center[0] + 1, center[1] - 1)
+        goal = (center[0] + 1, center[1] - 2)
+        add_rail(width, height, grid_map, start, via, goal)
 
-    def make_switch_e_w(width,height,grid_map, center):
+    def make_switch_e_w(width, height, grid_map, center):
         # e -> w
         start = (center[0] + 1, center[1])
         via = (center[0] + 1, center[1] - 1)
         goal = (center[0], center[1] - 1)
-        add_rail(width,height,grid_map, start, via, goal)
+        add_rail(width, height, grid_map, start, via, goal)
         start = (center[0] + 1, center[1] - 1)
         via = (center[0], center[1] - 1)
         goal = (center[0], center[1] - 2)
-        add_rail(width,height,grid_map, start, via, goal)
+        add_rail(width, height, grid_map, start, via, goal)
 
     class Grid4TransitionsEnum(IntEnum):
         NORTH = 0
@@ -669,10 +670,8 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis
 
         np.random.seed(seed + num_resets)
 
-
-
-        max_n_track_seg=4
-        x_offsets = np.arange(0,height,max_n_track_seg).astype(int)
+        max_n_track_seg = np.random.choice([3, 4, 5])
+        x_offsets = np.arange(0, height, max_n_track_seg).astype(int)
 
         agents_position = []
         agents_target = []
@@ -681,7 +680,7 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis
         for off_set_loop in range(len(x_offsets)):
             off_set = x_offsets[off_set_loop]
             # second track
-            data = np.arange(int((width -4- max_n_track_seg) / max_n_track_seg)) * max_n_track_seg + 4
+            data = np.arange(int((width - 4 - max_n_track_seg) / max_n_track_seg)) * max_n_track_seg + 4
             n_track_seg = np.random.choice(max_n_track_seg) + 1
 
             start_track = (off_set, 0)
@@ -691,86 +690,73 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis
             # track one (full track : left right)
             if off_set_loop > 0:
                 if off_set_loop % 2 == 1:
-                    start_track = (x_offsets[off_set_loop-1]+1, width - 1)
-                    goal_track = (x_offsets[off_set_loop]-1, width - 1)
+                    start_track = (x_offsets[off_set_loop - 1] + 1, width - 1)
+                    goal_track = (x_offsets[off_set_loop] - 1, width - 1)
                     new_path = connect_rail(rail_trans, rail_array, start_track, goal_track)
 
-                    add_rail(width,height,grid_map,
-                             (x_offsets[off_set_loop-1], width - 2),
-                             (x_offsets[off_set_loop-1], width - 1),
-                             (x_offsets[off_set_loop-1]+1, width - 1))
-                    add_rail(width,height,grid_map,
+                    add_rail(width, height, grid_map,
+                             (x_offsets[off_set_loop - 1], width - 2),
+                             (x_offsets[off_set_loop - 1], width - 1),
+                             (x_offsets[off_set_loop - 1] + 1, width - 1))
+                    add_rail(width, height, grid_map,
                              (x_offsets[off_set_loop], width - 2),
                              (x_offsets[off_set_loop], width - 1),
-                             (x_offsets[off_set_loop]-1, width - 1))
-                    add_rail(width,height,grid_map,
-                             (x_offsets[off_set_loop-1], width - 1),
-                             (x_offsets[off_set_loop-1]+1, width - 1),
-                             (x_offsets[off_set_loop-1]+2, width - 1))
-                    add_rail(width,height,grid_map,
+                             (x_offsets[off_set_loop] - 1, width - 1))
+                    add_rail(width, height, grid_map,
+                             (x_offsets[off_set_loop - 1], width - 1),
+                             (x_offsets[off_set_loop - 1] + 1, width - 1),
+                             (x_offsets[off_set_loop - 1] + 2, width - 1))
+                    add_rail(width, height, grid_map,
                              (x_offsets[off_set_loop], width - 1),
-                             (x_offsets[off_set_loop]-1, width - 1),
-                             (x_offsets[off_set_loop]-2, width - 1))
+                             (x_offsets[off_set_loop] - 1, width - 1),
+                             (x_offsets[off_set_loop] - 2, width - 1))
 
                 else:
-                    start_track = (x_offsets[off_set_loop-1]+1,0)
-                    goal_track = (x_offsets[off_set_loop]-1, 0)
+                    start_track = (x_offsets[off_set_loop - 1] + 1, 0)
+                    goal_track = (x_offsets[off_set_loop] - 1, 0)
                     new_path = connect_rail(rail_trans, rail_array, start_track, goal_track)
 
-                    add_rail(width,height,grid_map,
-                             (x_offsets[off_set_loop-1], 1),
-                             (x_offsets[off_set_loop-1], 0),
-                             (x_offsets[off_set_loop-1]+1, 0))
-                    add_rail(width,height,grid_map,
+                    add_rail(width, height, grid_map,
+                             (x_offsets[off_set_loop - 1], 1),
+                             (x_offsets[off_set_loop - 1], 0),
+                             (x_offsets[off_set_loop - 1] + 1, 0))
+                    add_rail(width, height, grid_map,
                              (x_offsets[off_set_loop], 1),
                              (x_offsets[off_set_loop], 0),
-                             (x_offsets[off_set_loop]-1, 0))
-                    add_rail(width,height,grid_map,
-                             (x_offsets[off_set_loop-1], 0),
-                             (x_offsets[off_set_loop-1]+1, 0),
-                             (x_offsets[off_set_loop-1]+2, 0))
-                    add_rail(width,height,grid_map,
+                             (x_offsets[off_set_loop] - 1, 0))
+                    add_rail(width, height, grid_map,
+                             (x_offsets[off_set_loop - 1], 0),
+                             (x_offsets[off_set_loop - 1] + 1, 0),
+                             (x_offsets[off_set_loop - 1] + 2, 0))
+                    add_rail(width, height, grid_map,
                              (x_offsets[off_set_loop], 0),
-                             (x_offsets[off_set_loop]-1, 0),
-                             (x_offsets[off_set_loop]-2, 0))
+                             (x_offsets[off_set_loop] - 1, 0),
+                             (x_offsets[off_set_loop] - 2, 0))
 
-            for nbr_track_loop in range(height-1):
-                if len(data) < 2*n_track_seg+1:
+            for nbr_track_loop in range(height - 1):
+                if len(data) < 2 * n_track_seg + 1:
                     break
                 x = np.sort(np.random.choice(data, 2 * n_track_seg, False)).astype(int)
                 data = []
                 for x_loop in range(int(len(x) / 2)):
-                    start = (max(0,min(off_set+nbr_track_loop+1,height-1)), max(0,min(x[2 * x_loop],width-1)))
-                    goal = (max(0,min(off_set+nbr_track_loop+1,height-1)), max(0,min(x[2 * x_loop + 1],width-1)))
-                    d = np.arange(x[2 * x_loop]+1,x[2 * x_loop+1]-1,2)
+                    start = (max(0, min(off_set + nbr_track_loop + 1, height - 1)), max(0, min(x[2 * x_loop], width - 1)))
+                    goal = (max(0, min(off_set + nbr_track_loop + 1, height - 1)), max(0, min(x[2 * x_loop + 1], width - 1)))
+                    d = np.arange(x[2 * x_loop] + 1, x[2 * x_loop + 1] - 1, 2)
                     data.extend(d)
 
                     new_path = connect_rail(rail_trans, rail_array, start, goal)
-                    if len(new_path) >0:
-                        c = (off_set+nbr_track_loop, x[2 * x_loop] + 1)
-                        make_switch_e_w(width,height,grid_map, c)
-                        c = (off_set+nbr_track_loop, x[2 * x_loop + 1]+1)
-                        make_switch_w_e(width,height,grid_map, c)
-
-
-                    add_pos = (int((start[0]+goal[0])/2),int((start[1]+goal[1])/2))
-                    agents_position.append(add_pos)
-                    agents_target.append(add_pos)
-                    agents_direction.append(np.random.choice([3,1]))
-
-        print(agents_direction)
-        print(agents_position)
-        print(agents_target)
-
-        x = np.arange(len(agents_direction))
-        num_a = min(num_agents,np.floor(len(agents_direction)/2))
-        if num_a > 1:
-            filter_agent = np.random.choice(x,num_a,False)
-            agents_position = agents_position[filter_agent]
-            agents_direction = agents_direction[filter_agent]
-            np.delete(x,filter_agent)
-            filter_agent = np.random.choice(x,num_a,False)
-            agents_target = agents_target[filter_agent]
+                    if len(new_path) > 0:
+                        c = (off_set + nbr_track_loop, x[2 * x_loop] + 1)
+                        make_switch_e_w(width, height, grid_map, c)
+                        c = (off_set + nbr_track_loop, x[2 * x_loop + 1] + 1)
+                        make_switch_w_e(width, height, grid_map, c)
+
+                    if nbr_track_loop > 0:
+                        add_pos = (int((start[0] + goal[0]) / 2), int((start[1] + goal[1]) / 2))
+                        agents_position.append(add_pos)
+                        agents_target.append(add_pos)
+                        agents_direction.append(np.random.choice([3, 1]))
+
 
         return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
 
@@ -838,8 +824,6 @@ def sparse_rail_generator(nr_nodes=100, max_neigbours=2, min_node_dist=20, node_
                 new_path = connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb])
             node_stack.pop(0)
 
-
-
         # Generate start and target node directory for all agents
         agent_start_targets_nodes = []
         for agent_idx in range(num_agents):
diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py
index 1e86026cf8e978d872a715c5ccc1abff357e60ac..7079a4f9e3b176aaad5398d128eea63c7f8adccd 100644
--- a/tests/test_flatland_env_sparse_rail_generator.py
+++ b/tests/test_flatland_env_sparse_rail_generator.py
@@ -5,19 +5,21 @@ from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
 
+import numpy as np
 
 def test_realistic_rail_generator():
-
-    env = RailEnv(width=40,
-                  height=16,
-                  rail_generator=realistic_rail_generator(),
-                  number_of_agents=15,
-                  obs_builder_object=GlobalObsForRailEnv())
-    # reset to initialize agents_static
-    env_renderer = RenderTool(env, gl="PILSVG", )
-    env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
-    time.sleep(10)
-
+    for test_loop in range(20):
+        num_agents = np.random.randint(10,30)
+        env = RailEnv(width=np.random.randint(40,80),
+                      height=np.random.randint(10,20),
+                      rail_generator=realistic_rail_generator(nr_start_goal=num_agents+1,seed=test_loop),
+                      number_of_agents=num_agents,
+                      obs_builder_object=GlobalObsForRailEnv())
+        # reset to initialize agents_static
+        env_renderer = RenderTool(env, gl="PILSVG", )
+        env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
+        time.sleep(1)
+        env_renderer.close_window()
 
 def test_sparse_rail_generator():