diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index ab2d02912bf61b796b5c7ced312918b7cf56615c..6d1205b3b505f77dd57cd670e669c93401200b5d 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -1,3 +1,4 @@
+import warnings
 from enum import IntEnum
 
 import msgpack
@@ -8,7 +9,7 @@ from flatland.core.grid.grid_utils import distance_on_rail
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic
-from flatland.envs.grid4_generators_utils import connect_rail
+from flatland.envs.grid4_generators_utils import connect_rail, connect_from_nodes, connect_nodes, connect_to_nodes
 from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail
 
 
@@ -692,7 +693,7 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis
     return generator
 
 
-def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, mean_node_neighbours=2, min_node_dist=20, node_radius=2,
+def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, max_neigbours=2, min_node_dist=20, node_radius=2,
                           seed=0):
     '''
 
@@ -708,7 +709,7 @@ def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, mean_node_neighbour
 
         if num_agents > nr_train_stations:
             num_agents = nr_train_stations
-            print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
+            warnings.warn("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
         rail_trans = RailEnvTransitions()
         grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
         rail_array = grid_map.grid
@@ -719,21 +720,52 @@ def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, mean_node_neighbour
         node_positions = []
         for node_idx in range(nr_nodes):
             to_close = True
+            tries = 0
             while to_close:
-                x_tmp = np.random.randint(width)
-                y_tmp = np.random.randint(height)
+                x_tmp = np.random.randint(height)
+                y_tmp = np.random.randint(width)
                 to_close = False
                 for node_pos in node_positions:
                     if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
                         to_close = True
                 if not to_close:
                     node_positions.append((x_tmp, y_tmp))
+                tries += 1
+                if tries > 100:
+                    warnings.warn("Could not set nodes, please change initial parameters!!!!")
+                    break
+
+        # Chose node connection
+        available_nodes = np.arange(nr_nodes)
+        current_node = 0
+        node_stack = [current_node]
+
+        while len(node_stack) > 0:
+            current_node = node_stack[0]
+            delete_idx = np.where(available_nodes == current_node)
+            available_nodes = np.delete(available_nodes, delete_idx, 0)
+
+            # Get random number of neighbors
+            num_neighb = 2  # np.random.randint(1, max_neigbours)
+            if len(available_nodes) >= num_neighb:
+                connected_neighb_idx = np.random.choice(available_nodes, num_neighb, replace=False)
+            else:
+                connected_neighb_idx = available_nodes
+
+            for neighb in connected_neighb_idx:
+                if neighb not in node_stack:
+                    node_stack.append(neighb)
+                new_path = connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb])
+            node_stack.pop(0)
+
+
 
         # Generate start and target node directory for all agents
         agent_start_targets_nodes = []
         for agent_idx in range(num_agents):
             start_target_tuple = np.random.choice(nr_nodes, 2, replace=False)
             agent_start_targets_nodes.append(start_target_tuple)
+
         # Generate actual start and target locations from around nodes
         agents_position = []
         agents_target = []
@@ -741,42 +773,39 @@ def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, mean_node_neighbour
         agent_idx = 0
         for start_target in agent_start_targets_nodes:
             start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius), 0,
-                              width - 1)
-            start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius), 0,
                               height - 1)
+            start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius), 0,
+                              width - 1)
             target_x = np.clip(node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0,
-                               width - 1)
-            target_y = np.clip(node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0,
                                height - 1)
+            target_y = np.clip(node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0,
+                               width - 1)
             if agent_idx == 0:
-                agents_position.append((start_y, start_x))
-                agents_target.append((target_y, target_x))
+                agents_position.append((start_x, start_y))
+                agents_target.append((target_x, target_y))
             else:
-                while ((start_x, start_y) in agents_position or (target_x, target_y) in agents_target):
+                # Make sure we don't put to starts or targets on same cell
+                while (start_x, start_y) in agents_position or (target_x, target_y) in agents_target:
                     start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius),
                                       0,
-                                      width - 1)
+                                      height - 1)
                     start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius),
                                       0,
-                                      height - 1)
+                                      width - 1)
                     target_x = np.clip(
                         node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0,
-                        width - 1)
+                        height - 1)
                     target_y = np.clip(
                         node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0,
-                        height - 1)
-                agents_position.append((start_y, start_x))
-                agents_target.append((target_y, target_x))
-
+                        width - 1)
+                agents_position.append((start_x, start_y))
+                agents_target.append((target_x, target_y))
+            new_path = connect_to_nodes(rail_trans, rail_array, agents_position[agent_idx],
+                                        node_positions[start_target[0]])
+            new_path = connect_from_nodes(rail_trans, rail_array, node_positions[start_target[1]],
+                                          agents_target[agent_idx])
             agents_direction.append(0)
             agent_idx += 1
-
-        print(agents_position)
-        print(agents_target)
-        print(node_positions)
-        for n in node_positions:
-            for m in node_positions:
-                print(distance_on_rail(n, m))
         return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
 
     return generator
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index dedd76b6bfd04c13ad59092adfefbde6ae98fc18..9116adb6639fc699d00b45b58241a4bdcdfe9c74 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -57,6 +57,143 @@ def connect_rail(rail_trans, rail_array, start, end):
     return path
 
 
+def connect_nodes(rail_trans, rail_array, start, end):
+    """
+    Creates a new path [start,end] in rail_array, based on rail_trans.
+    """
+    # in the worst case we will need to do a A* search, so we might as well set that up
+    path = a_star(rail_trans, rail_array, start, end)
+    if len(path) < 2:
+        return []
+    current_dir = get_direction(path[0], path[1])
+    end_pos = path[-1]
+    for index in range(len(path) - 1):
+        current_pos = path[index]
+        new_pos = path[index + 1]
+        new_dir = get_direction(current_pos, new_pos)
+
+        new_trans = rail_array[current_pos]
+        if index == 0:
+            if new_trans == 0:
+                # end-point
+                # need to flip direction because of how end points are defined
+                new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+            else:
+                # into existing rail
+                new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+        else:
+            # set the forward path
+            new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+            # set the backwards path
+            new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
+        rail_array[current_pos] = new_trans
+
+        if new_pos == end_pos:
+            # setup end pos setup
+            new_trans_e = rail_array[end_pos]
+            if new_trans_e == 0:
+                # end-point
+                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
+            else:
+                # into existing rail
+                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
+            rail_array[end_pos] = new_trans_e
+
+        current_dir = new_dir
+    return path
+
+
+def connect_from_nodes(rail_trans, rail_array, start, end):
+    """
+    Creates a new path [start,end] in rail_array, based on rail_trans.
+    """
+    # in the worst case we will need to do a A* search, so we might as well set that up
+    path = a_star(rail_trans, rail_array, start, end)
+    if len(path) < 2:
+        return []
+    current_dir = get_direction(path[0], path[1])
+    end_pos = path[-1]
+    for index in range(len(path) - 1):
+        current_pos = path[index]
+        new_pos = path[index + 1]
+        new_dir = get_direction(current_pos, new_pos)
+
+        new_trans = rail_array[current_pos]
+        if index == 0:
+            if new_trans == 0:
+                # end-point
+                # need to flip direction because of how end points are defined
+                new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+            else:
+                # into existing rail
+                new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+        else:
+            # set the forward path
+            new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+            # set the backwards path
+            new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
+        rail_array[current_pos] = new_trans
+
+        if new_pos == end_pos:
+            # setup end pos setup
+            new_trans_e = rail_array[end_pos]
+            if new_trans_e == 0:
+                # end-point
+                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
+            else:
+                # into existing rail
+                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
+            rail_array[end_pos] = new_trans_e
+
+        current_dir = new_dir
+    return path
+
+
+def connect_to_nodes(rail_trans, rail_array, start, end):
+    """
+    Creates a new path [start,end] in rail_array, based on rail_trans.
+    """
+    # in the worst case we will need to do a A* search, so we might as well set that up
+    path = a_star(rail_trans, rail_array, start, end)
+    if len(path) < 2:
+        return []
+    current_dir = get_direction(path[0], path[1])
+    end_pos = path[-1]
+    for index in range(len(path) - 1):
+        current_pos = path[index]
+        new_pos = path[index + 1]
+        new_dir = get_direction(current_pos, new_pos)
+
+        new_trans = rail_array[current_pos]
+        if index == 0:
+            if new_trans == 0:
+                # end-point
+                # need to flip direction because of how end points are defined
+                new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
+            else:
+                # into existing rail
+                new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+        else:
+            # set the forward path
+            new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+            # set the backwards path
+            new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
+        rail_array[current_pos] = new_trans
+
+        if new_pos == end_pos:
+            # setup end pos setup
+            new_trans_e = rail_array[end_pos]
+            if new_trans_e == 0:
+                # end-point
+                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
+            else:
+                # into existing rail
+                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
+            rail_array[end_pos] = new_trans_e
+
+        current_dir = new_dir
+    return path
+
 def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
     """
     Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py
index 8934a4022611364f0adc4df806376fd43e5bc8f4..b64aaa640d27f24b1ab2bd87e30a29784f7787e7 100644
--- a/tests/test_flatland_env_sparse_rail_generator.py
+++ b/tests/test_flatland_env_sparse_rail_generator.py
@@ -1,14 +1,20 @@
+import time
+
 from flatland.envs.generators import sparse_rail_generator
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool
 
 
 def test_sparse_rail_generator():
+
     env = RailEnv(width=20,
                   height=20,
-                  rail_generator=sparse_rail_generator(nr_train_stations=10, nr_nodes=5, min_node_dist=10,
+                  rail_generator=sparse_rail_generator(nr_train_stations=3, nr_nodes=2, min_node_dist=5,
                                                        node_radius=4),
-                  number_of_agents=10,
+                  number_of_agents=3,
                   obs_builder_object=GlobalObsForRailEnv())
     # reset to initialize agents_static
-    env.reset()
+    env_renderer = RenderTool(env, gl="PILSVG", )
+    env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
+    time.sleep(10)