diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index be21f86cba60eb6b4d3b3ad56e32202ad002a9b7..2441e3e7652828d9c83a4ca5860e030f7921e8a8 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -673,10 +673,16 @@ def realistic_rail_generator(nr_start_goal=1,  seed=0):
         max_n_track_seg = np.random.choice([3, 4, 5])
         x_offsets = np.arange(0, height, max_n_track_seg).astype(int)
 
-        agents_positions = []
+        agents_positions_forward = []
+        agents_directions_forward = []
+        agents_positions_backward = []
+        agents_directions_backward = []
         agents_targets = []
-        agents_directions = []
 
+        idx_forward = []
+        idx_backward = []
+
+        idx_target=0
         for off_set_loop in range(len(x_offsets)):
             off_set = x_offsets[off_set_loop]
             # second track
@@ -733,7 +739,7 @@ def realistic_rail_generator(nr_start_goal=1,  seed=0):
                              (x_offsets[off_set_loop] - 1, 0),
                              (x_offsets[off_set_loop] - 2, 0))
 
-            for nbr_track_loop in range(height - 1):
+            for nbr_track_loop in range(max_n_track_seg-1):
                 if len(data) < 2 * n_track_seg + 1:
                     break
                 x = np.sort(np.random.choice(data, 2 * n_track_seg, False)).astype(int)
@@ -752,24 +758,52 @@ def realistic_rail_generator(nr_start_goal=1,  seed=0):
                         make_switch_w_e(width, height, grid_map, c)
 
                     add_pos = (int((start[0] + goal[0]) / 2), int((start[1] + goal[1]) / 2))
-                    agents_positions.append(add_pos)
-                    agents_directions.append(3)
-                    add_pos = (int((start[0] + goal[0]) / 2), int((2*start[1] + goal[1]) / 3))
+                    if nbr_track_loop % 2 == 0:
+                        agents_positions_forward.append(add_pos)
+                        agents_directions_forward.append(([1, 3][off_set_loop % 2]))
+                        idx_forward.append(idx_target)
+                    else:
+                        agents_positions_backward.append(add_pos)
+                        agents_directions_backward.append(([1, 3][off_set_loop % 2]))
+                        idx_backward.append(idx_target)
+
+                    add_pos = (int((start[0] + goal[0]) / 2), int((2*start[1] + goal[1]) / 3),idx_target)
                     agents_targets.append(add_pos)
+                    idx_target+=1
 
         agents_position = []
         agents_target = []
         agents_direction = []
-        filter_agent = np.random.choice(np.arange(len(agents_positions)),min(len(agents_positions),num_agents),False)
-        for f in filter_agent:
-            d = agents_positions[f]
-            agents_position.append(d)
-            d = agents_directions[f]
-            agents_direction.append(d)
-        filter_target = np.random.choice(np.arange(len(agents_targets)),min(len(agents_targets),num_agents),False)
-        for f in filter_target:
-            d = agents_targets[f]
-            agents_target.append(d)
+
+        for a in range(min(len(agents_targets),num_agents)):
+            t = np.random.choice(range(len(agents_targets)))
+            d = agents_targets[t]
+            agents_targets.pop(t)
+            if d[2] < idx_target / 2:
+                if len(idx_backward) > 0:
+                    agents_target.append((d[0], d[1]))
+                    sel = np.random.choice(range(len(idx_backward)))
+                    # backward
+                    p = agents_positions_backward[sel]
+                    d = agents_directions_backward[sel]
+                    agents_positions_backward.pop(sel)
+                    agents_directions_backward.pop(sel)
+                    idx_backward.pop(sel)
+                    agents_position.append((p[0],p[1]))
+                    agents_direction.append(d)
+            else:
+                if len(idx_forward) > 0:
+                    agents_target.append((d[0], d[1]))
+                    sel = np.random.choice(range(len(idx_forward)))
+                    # forward
+                    p = agents_positions_forward[sel]
+                    d = agents_directions_forward[sel]
+                    agents_positions_forward.pop(sel)
+                    agents_directions_forward.pop(sel)
+                    idx_forward.pop(sel)
+                    agents_position.append((p[0],p[1]))
+                    agents_direction.append(d)
+
 
         return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
 
diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py
index 2a361ca99c4b83da933abe674428afc15347e4db..3a08f5b2f7cdf9b7103194348266d997be3dda04 100644
--- a/tests/test_flatland_env_sparse_rail_generator.py
+++ b/tests/test_flatland_env_sparse_rail_generator.py
@@ -9,7 +9,7 @@ from flatland.utils.rendertools import RenderTool
 
 
 def test_realistic_rail_generator():
-    for test_loop in range(5):
+    for test_loop in range(20):
         num_agents = np.random.randint(10,30)
         env = RailEnv(width=np.random.randint(40,80),
                       height=np.random.randint(10,20),
@@ -23,6 +23,18 @@ def test_realistic_rail_generator():
         env_renderer.close_window()
 
 def test_sparse_rail_generator():
+
+    env = RailEnv(width=20,
+                  height=20,
+                  rail_generator=sparse_rail_generator(nr_nodes=3, min_node_dist=8,
+                                                       node_radius=4),
+                  number_of_agents=15,
+
+    env = RailEnv(width=20,
+                  height=20,
+                  rail_generator=sparse_rail_generator(nr_nodes=3, min_node_dist=8,
+                                                       node_radius=4),
+                  number_of_agents=15,
     env = RailEnv(width=50,
                   height=50,
                   rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
@@ -38,5 +50,7 @@ def test_sparse_rail_generator():
     # reset to initialize agents_static
     env_renderer = RenderTool(env, gl="PILSVG", )
     env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
+    time.sleep(2)
+
     env_renderer.gl.save_image("flatalnd_2_0.png")
     time.sleep(100)