diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 652671ae48b5cd269a20e79a31657c2e7434b8d4..ab2d02912bf61b796b5c7ced312918b7cf56615c 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -690,3 +690,93 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis
         return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
 
     return generator
+
+
+def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, mean_node_neighbours=2, min_node_dist=20, node_radius=2,
+                          seed=0):
+    '''
+
+    :param nr_train_stations:
+    :param nr_nodes:
+    :param mean_node_neighbours:
+    :param min_node_dist:
+    :param seed:
+    :return:
+    '''
+
+    def generator(width, height, num_agents, num_resets=0):
+
+        if num_agents > nr_train_stations:
+            num_agents = nr_train_stations
+            print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
+        rail_trans = RailEnvTransitions()
+        grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
+        rail_array = grid_map.grid
+        rail_array.fill(0)
+        np.random.seed(seed + num_resets)
+
+        # Generate a set of nodes for the sparse network
+        node_positions = []
+        for node_idx in range(nr_nodes):
+            to_close = True
+            while to_close:
+                x_tmp = np.random.randint(width)
+                y_tmp = np.random.randint(height)
+                to_close = False
+                for node_pos in node_positions:
+                    if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
+                        to_close = True
+                if not to_close:
+                    node_positions.append((x_tmp, y_tmp))
+
+        # Generate start and target node directory for all agents
+        agent_start_targets_nodes = []
+        for agent_idx in range(num_agents):
+            start_target_tuple = np.random.choice(nr_nodes, 2, replace=False)
+            agent_start_targets_nodes.append(start_target_tuple)
+        # Generate actual start and target locations from around nodes
+        agents_position = []
+        agents_target = []
+        agents_direction = []
+        agent_idx = 0
+        for start_target in agent_start_targets_nodes:
+            start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius), 0,
+                              width - 1)
+            start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius), 0,
+                              height - 1)
+            target_x = np.clip(node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0,
+                               width - 1)
+            target_y = np.clip(node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0,
+                               height - 1)
+            if agent_idx == 0:
+                agents_position.append((start_y, start_x))
+                agents_target.append((target_y, target_x))
+            else:
+                while ((start_x, start_y) in agents_position or (target_x, target_y) in agents_target):
+                    start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius),
+                                      0,
+                                      width - 1)
+                    start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius),
+                                      0,
+                                      height - 1)
+                    target_x = np.clip(
+                        node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0,
+                        width - 1)
+                    target_y = np.clip(
+                        node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0,
+                        height - 1)
+                agents_position.append((start_y, start_x))
+                agents_target.append((target_y, target_x))
+
+            agents_direction.append(0)
+            agent_idx += 1
+
+        print(agents_position)
+        print(agents_target)
+        print(node_positions)
+        for n in node_positions:
+            for m in node_positions:
+                print(distance_on_rail(n, m))
+        return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+
+    return generator
diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..8934a4022611364f0adc4df806376fd43e5bc8f4
--- /dev/null
+++ b/tests/test_flatland_env_sparse_rail_generator.py
@@ -0,0 +1,14 @@
+from flatland.envs.generators import sparse_rail_generator
+from flatland.envs.observations import GlobalObsForRailEnv
+from flatland.envs.rail_env import RailEnv
+
+
+def test_sparse_rail_generator():
+    env = RailEnv(width=20,
+                  height=20,
+                  rail_generator=sparse_rail_generator(nr_train_stations=10, nr_nodes=5, min_node_dist=10,
+                                                       node_radius=4),
+                  number_of_agents=10,
+                  obs_builder_object=GlobalObsForRailEnv())
+    # reset to initialize agents_static
+    env.reset()