+import numpy as np
+from flatland.envs.generators import sparse_rail_generator
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool
+# Use the new sparse_rail_generator to generate feasible network configurations with corresponding tasks
+# Training on simple small tasks is the best way to get familiar with the environment
+# Use the malfunction generator to break agents from time to time
+stochastic_data = {'prop_malfunction': 0.5,  # Percentage of defective agents
+                   'malfunction_rate': 30,  # Rate of malfunction occurrence
+                   'min_duration': 3,  # Minimal duration of malfunction
+                   'max_duration': 10  # Max duration of malfunction
+                   }
+# Tree observation of depth 2 that uses a shortest-path predictor
+TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
+# 20x20 environment with 5 agents on a generated sparse rail network
+env = RailEnv(width=20,
+              height=20,
+              rail_generator=sparse_rail_generator(num_cities=2,  # Number of cities in map (where train stations are)
+                                                   num_intersections=1,  # Number of intersections (no start / target)
+                                                   num_trainstations=15,  # Number of possible start/targets on map
+                                                   min_node_dist=3,  # Minimal distance of nodes
+                                                   node_radius=3,  # Proximity of stations to city center
+                                                   num_neighb=2,  # Number of connections to other cities/intersections
+                                                   seed=15,  # Random seed
+                                                   realistic_mode=True,
+                                                   enhance_intersection=True
+                                                   ),
+              number_of_agents=5,
+              stochastic_data=stochastic_data,  # Malfunction data generator
+              obs_builder_object=TreeObservation)
+# Renderer attached to the environment created above
+env_renderer = RenderTool(env, gl="PILSVG", )
+# Import your own Agent or use RLlib to train agents on Flatland
+# As an example we use a random agent instead
+class RandomAgent:
+    """Stand-in agent: ignores what it observes and samples its actions at random."""
+    def __init__(self, state_size, action_size):
+        self.state_size, self.action_size = state_size, action_size
+    def act(self, state):
+        """
+        Draw one action uniformly from the action space; the observation is not used.
+        :param state: observation of the agent (ignored)
+        :return: an action index in [0, action_size)
+        """
+        candidate_actions = np.arange(self.action_size)
+        return np.random.choice(candidate_actions)
+    def step(self, memories):
+        """
+        Learning hook; a random agent has nothing to learn, so this does nothing.
+        :param memories: SARS tuple (state, action, reward, next state, done)
+        :return: None
+        """
+        return None
+    def save(self, filename):
+        # Nothing to persist: there is no policy
+        return None
+    def load(self, filename):
+        # Nothing to restore: there is no policy
+        return None
+# Build the agent with sizes matching the environment and the observation builder.
+# The action space is limited to 4 actions, which leaves out the stop action.
+agent = RandomAgent(218, 4)
+# Holds the action chosen for every agent handle
+action_dict = {}
+print("Start episode...")
+# Reset the environment; this yields the initial observations of all agents
+obs = env.reset()
+# Give the agents different speeds: 1, 1/2, 1/3, 1/4, 1/5, then repeating
+for idx in range(env.get_num_agents()):
+    speed = 1.0 / (1.0 + (idx % 5))
+    env.agents[idx].speed_data["speed"] = speed
+# NOTE(review): the original comment mentions resetting the rendering system here,
+# but no renderer call is made at this point.
+# Observations could be normalized at this point as well;
+# see the training navigation example in the baseline repository
+score, frame_step = 0, 0
+# Run a single episode of at most 500 steps
+for step in range(500):
+    # Pick an action for every agent in the environment
+    for a in range(env.get_num_agents()):
+        action = agent.act(obs[a])
+        action_dict[a] = action
+    # Advance the environment: returns the new observations of all agents,
+    # their rewards and whether they are done
+    next_obs, all_rewards, done, _ = env.step(action_dict)
+    env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
+    frame_step += 1
+    # Hand each transition to the agent and accumulate the reward
+    for a in range(env.get_num_agents()):
+        agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
+        score += all_rewards[a]
+    obs = next_obs.copy()
+    if done['__all__']:
+        break
+print('Episode: Steps {}\t Score = {}'.format(step, score))
@@ -350,4 +350,124 @@ class GridTransitionMap(TransitionMap):
         return True
+    def fix_neighbours(self, rcPos, check_this_cell=False):
+        """
+        Check the outbound transitions of the cell at rcPos = tuple(row, column)
+        against the surrounding cells and patch the first neighbour found lacking.
+        For every outbound direction of this cell:
+        - if the adjacent cell lies outside the grid, False is returned;
+        - if the adjacent cell has no transition at all when entered in that
+            direction, the transition from that direction to its mirrored
+            direction is set on the adjacent cell and False is returned.
+        At most one neighbour is modified per call, because the method returns
+        right after patching it.
+        If check_this_cell is True, the transitions of the cell itself are first
+        checked with transitions.is_valid and False is returned if they are invalid.
+        These are NOT checked - see transition.is_valid:
+        - all transitions have the mirror transitions (N->E <=> W->S)
+        - Reverse transitions (N -> S) only exist for a dead-end
+        - a cell contains either no dead-ends or exactly one
+        Returns: True (nothing had to be changed) or False (invalid, or a neighbour was patched)
+        """
+        cell_transition = self.grid[tuple(rcPos)]
+        if check_this_cell:
+            if not self.transitions.is_valid(cell_transition):
+                return False
+        gDir2dRC = self.transitions.gDir2dRC  # [[-1,0] = N, [0,1]=E, etc]
+        grcPos = array(rcPos)
+        grcMax = self.grid.shape
+        binTrans = self.get_full_transitions(*rcPos)  # 16bit integer - all trans in/out
+        lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8)  # 2 x uint8
+        g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4)  # 4x4 x uint8 binary(0,1)
+        gDirOut = g2binTrans.any(axis=0)  # outbound directions as boolean array (4)
+        giDirOut = np.argwhere(gDirOut)[:, 0]  # valid outbound directions as array of int
+        # loop over available outbound directions (indices) for rcPos
+        for iDirOut in giDirOut:
+            gdRC = gDir2dRC[iDirOut]  # row,col increment
+            gPos2 = grcPos + gdRC  # next cell in that direction
+            # Check the adjacent cell is within bounds
+            # if not, then this transition is invalid!
+            if np.any(gPos2 < 0):
+                return False
+            if np.any(gPos2 >= grcMax):
+                return False
+            # Get the transitions out of gPos2, using iDirOut as the inbound direction
+            # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
+            t4Trans2 = self.get_transitions(*gPos2, iDirOut)
+            if any(t4Trans2):
+                continue
+            else:
+                # Patch the neighbour: when entered heading iDirOut it may now leave
+                # in the mirrored direction. Report the cell as invalid for this call.
+                self.set_transition((gPos2[0], gPos2[1], iDirOut), mirror(iDirOut), 1)
+                return False
+        return True
+    def fix_transitions(self, rcPos):
+        """
+        Rebuild the transitions of the cell at rcPos = tuple(row, column) so that
+        they match the neighbouring cells.
+        The cell is cleared first. A side of the cell counts as connected when the
+        in-grid neighbour on that side has, for any orientation, a transition
+        pointing back towards this cell. Depending on the number of connected sides:
+        - 0: the cell stays empty
+        - 1: for side d, the transition from orientation mirror(d) to direction d
+            is set (a turn-back towards the only connected side)
+        - 2: the two sides are linked to each other in both directions
+        - 3: with the unconnected side called hole, side (hole + 1) % 4 is linked
+            to each of the two other sides, in both directions
+        - 4: a fixed pattern of eight transitions is set
+        :param rcPos: (row, column) of the cell to rebuild
+        :return: always True
+        """
+        gDir2dRC = self.transitions.gDir2dRC  # [[-1,0] = N, [0,1]=E, etc]
+        grcPos = array(rcPos)
+        grcMax = self.grid.shape
+        row, col = rcPos[0], rcPos[1]
+        # Start from an empty cell; its transitions are re-created below
+        self.set_transitions(rcPos, 0)
+        # Sides with a neighbour connecting back to this cell. Kept as plain ints so
+        # that set_transition always receives scalar orientations/directions
+        # (np.argwhere would hand out one-element arrays instead).
+        incoming = []
+        for iDirOut in range(4):
+            gPos2 = grcPos + gDir2dRC[iDirOut]  # neighbouring cell in that direction
+            # Neighbours outside the grid are ignored for the count of connections
+            if np.any(gPos2 < 0) or np.any(gPos2 >= grcMax):
+                continue
+            # Does the neighbour, for any orientation, lead back towards rcPos?
+            towards_cell = mirror(iDirOut)
+            connected = sum(self.get_transition((gPos2[0], gPos2[1], orientation), towards_cell)
+                            for orientation in range(4))
+            if connected > 0:
+                incoming.append(iDirOut)
+        number_of_incoming = len(incoming)
+        if number_of_incoming == 1:
+            # Single connection: entering from that side leads back out through it
+            direction = incoming[0]
+            self.set_transition((row, col, mirror(direction)), direction, 1)
+        elif number_of_incoming == 2:
+            # Connect the two sides with each other, both ways
+            first, second = incoming
+            self.set_transition((row, col, mirror(first)), second, 1)
+            self.set_transition((row, col, mirror(second)), first, 1)
+        elif number_of_incoming == 3:
+            # Find a feasible connection for three entries: the side right after
+            # the unconnected one is linked with the two remaining sides
+            hole = next(d for d in range(4) if d not in incoming)
+            connect_directions = [(hole + 1) % 4, (hole + 2) % 4, (hole + 3) % 4]
+            self.set_transition((row, col, mirror(connect_directions[0])), connect_directions[1], 1)
+            self.set_transition((row, col, mirror(connect_directions[0])), connect_directions[2], 1)
+            self.set_transition((row, col, mirror(connect_directions[1])), connect_directions[0], 1)
+            self.set_transition((row, col, mirror(connect_directions[2])), connect_directions[0], 1)
+        elif number_of_incoming == 4:
+            # All four sides connected: orientations 0 and 1 may leave towards 0 or 1,
+            # orientations 2 and 3 towards 2 or 3 (same pattern and order as before)
+            for orientation, direction in ((0, 0), (0, 1), (1, 0), (1, 1),
+                                           (2, 2), (2, 3), (3, 2), (3, 3)):
+                self.set_transition((row, col, orientation), direction, 1)
+        return True
+def mirror(dir):
+    """Return the direction opposite to `dir` among the four grid directions (0..3)."""
+    half_turn = 2
+    return (dir + half_turn) % 4
 # TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?)
@@ -53,3 +53,221 @@ def connect_rail(rail_trans, rail_array, start, end):
         current_dir = new_dir
     return path
+def _connect_path(rail_trans, rail_array, start, end, close_start=False, close_end=False):
+    """
+    Shared implementation of connect_nodes, connect_from_nodes and connect_to_nodes.
+    Searches a path from start to end with a_star and writes the transitions
+    along it into rail_array:
+    - every intermediate cell gets the forward transition and its mirrored
+      (backward) counterpart;
+    - a first or last cell that already holds rail gets the transition into
+      that existing rail;
+    - a first or last cell that is still empty is left empty, unless
+      close_start / close_end is set, in which case a single transition that
+      turns back onto the new path is written there.
+    :param rail_trans: transitions object providing set_transition
+    :param rail_array: grid of cell transitions, modified in place
+    :param start: start position of the path
+    :param end: end position of the path
+    :param close_start: write a turn-back transition on an empty first cell
+    :param close_end: write a turn-back transition on an empty last cell
+    :return: the path found, or [] if it has fewer than 2 cells
+    """
+    # in the worst case we will need to do a A* search, so we might as well set that up
+    path = a_star(rail_trans, rail_array, start, end)
+    if len(path) < 2:
+        return []
+    current_dir = get_direction(path[0], path[1])
+    end_pos = path[-1]
+    for index in range(len(path) - 1):
+        current_pos = path[index]
+        new_pos = path[index + 1]
+        new_dir = get_direction(current_pos, new_pos)
+        new_trans = rail_array[current_pos]
+        if index > 0:
+            # set the forward path
+            new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+            # set the backwards path
+            new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
+        elif new_trans != 0:
+            # path starts on existing rail: connect into it
+            new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+        elif close_start:
+            # empty start cell: direction is flipped because of how end points are defined
+            new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
+        # otherwise the empty start cell is left without transitions
+        rail_array[current_pos] = new_trans
+        if new_pos == end_pos:
+            # set up the last cell of the path
+            new_trans_e = rail_array[end_pos]
+            if new_trans_e != 0:
+                # path ends on existing rail: connect into it
+                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
+            elif close_end:
+                # empty end cell: arriving trains turn back onto the path
+                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
+            # otherwise the empty end cell is left without transitions
+            rail_array[end_pos] = new_trans_e
+        current_dir = new_dir
+    return path
+def connect_nodes(rail_trans, rail_array, start, end):
+    """
+    Creates a new path [start,end] in rail_array, based on rail_trans.
+    Empty cells at either end of the path are left without transitions.
+    Returns the path, or [] if none with at least 2 cells was found.
+    """
+    return _connect_path(rail_trans, rail_array, start, end)
+def connect_from_nodes(rail_trans, rail_array, start, end):
+    """
+    Creates a new path [start,end] in rail_array, based on rail_trans.
+    An empty start cell is left without transitions; an empty end cell
+    receives a turn-back transition.
+    Returns the path, or [] if none with at least 2 cells was found.
+    """
+    return _connect_path(rail_trans, rail_array, start, end, close_end=True)
+def connect_to_nodes(rail_trans, rail_array, start, end):
+    """
+    Creates a new path [start,end] in rail_array, based on rail_trans.
+    An empty start cell receives a turn-back transition; an empty end cell
+    is left without transitions.
+    Returns the path, or [] if none with at least 2 cells was found.
+    """
+    return _connect_path(rail_trans, rail_array, start, end, close_start=True)
+def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
+    """
+    Draw a random placement of agents on the `rail' GridTransitionMap.
+    Start and target cells are sampled (with np.random.choice) among the cells
+    that hold any transition. For each agent a start direction is picked at
+    random among those from which the target was found reachable. The whole
+    placement is drawn again as long as any agent has no such direction.
+    Returns three lists, one entry per agent: positions, directions, targets.
+    TODO: add extensive documentation, as users may need this function to simplify their custom level generators.
+    """
+    def _path_exists(rail, start, direction, end):
+        # Depth-first walk (states are popped from the end of the list) over
+        # (cell, heading) pairs; returns 1 once `end` is reached, 0 otherwise
+        seen = set()
+        pending = [(start, direction)]
+        while pending:
+            state = pending.pop()
+            cell, heading = state[0], state[1]
+            if cell[0] == end[0] and cell[1] == end[1]:
+                return 1
+            if state in seen:
+                continue
+            seen.add(state)
+            moves = rail.get_transitions(cell[0], cell[1], heading)
+            for move_index in range(4):
+                if moves[move_index]:
+                    pending.append((get_new_position(cell, move_index), move_index))
+            # A cell with exactly one transition bit set is a dead-end:
+            # also try the same cell with the heading reversed
+            transition_bits = rail.get_full_transitions(cell[0], cell[1])
+            if bin(int(transition_bits)).count("1") == 1:
+                pending.append((cell, (heading + 2) % 4))
+        return 0
+    # Every cell that carries some rail is a candidate start/target
+    valid_positions = [(r, c)
+                       for r in range(rail.height)
+                       for c in range(rail.width)
+                       if rail.get_full_transitions(r, c) > 0]
+    while True:
+        agents_position = [valid_positions[i]
+                           for i in np.random.choice(len(valid_positions), num_agents)]
+        agents_target = [valid_positions[i]
+                         for i in np.random.choice(len(valid_positions), num_agents)]
+        # agents_direction must be a direction for which a solution is
+        # guaranteed.
+        agents_direction = [0] * num_agents
+        placement_ok = True
+        for i in range(num_agents):
+            origin = agents_position[i]
+            # All (start direction, move) pairs possible from the start cell
+            valid_movements = []
+            for direction in range(4):
+                moves = rail.get_transitions(origin[0], origin[1], direction)
+                valid_movements.extend((direction, move_index)
+                                       for move_index in range(4) if moves[move_index])
+            valid_starting_directions = []
+            for start_dir, move_dir in valid_movements:
+                new_position = get_new_position(origin, move_dir)
+                if start_dir in valid_starting_directions:
+                    continue
+                # NOTE(review): the search starts at the next cell but with the start
+                # direction rather than the move direction - confirm this is intended
+                if _path_exists(rail, new_position, start_dir, agents_target[i]):
+                    valid_starting_directions.append(start_dir)
+            if len(valid_starting_directions) == 0:
+                # This agent cannot reach its target: draw everything again
+                placement_ok = False
+            else:
+                pick = np.random.choice(len(valid_starting_directions), 1)[0]
+                agents_direction[i] = valid_starting_directions[pick]
+        if placement_ok:
+            return agents_position, agents_direction, agents_target
@@ -2,7 +2,7 @@
 Definition of the RailEnv environment.
 # TODO:  _ this is a global method --> utils or remove later
+import warnings
 from enum import IntEnum
 import msgpack
@@ -11,6 +11,7 @@ import numpy as np
 from flatland.core.env import Environment
 from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_generators import random_rail_generator, RailGenerator
@@ -134,7 +135,8 @@ class RailEnv(Environment):
         self.rail_generator: RailGenerator = rail_generator
         self.agent_generator: ScheduleGenerator = agent_generator
-        self.rail = None
+        self.rail_generator = rail_generator
+        self.rail: GridTransitionMap = None
         self.width = width
         self.height = height
@@ -223,6 +225,12 @@ class RailEnv(Environment):
         if regen_rail or self.rail is None:
             self.rail = rail
             self.height, self.width = self.rail.grid.shape
+            for r in range(self.height):
+                for c in range(self.width):
+                    rcPos = (r, c)
+                    check = self.rail.cell_neighbours_valid(rcPos, True)
+                    if not check:
+                        warnings.warn("Invalid grid at {} -> {}".format(rcPos, check))
         if replace_agents:
             agents_hints = None
 """Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
+import warnings
 from typing import Callable, Tuple, Any, Optional
 import msgpack
@@ -8,7 +9,7 @@ from flatland.core.grid.grid4_utils import get_direction, mirror
 from flatland.core.grid.grid_utils import distance_on_rail
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.grid4_generators_utils import connect_rail
+from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes
 RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]]
 RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
@@ -523,3 +524,310 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener
         return return_rail, None
     return generator
+def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2, min_node_dist=20, node_radius=2,
+                          num_neighb=3, realistic_mode=False, enhance_intersection=False, seed=0):
+    """
+    This is a level generator which generates complex sparse rail configurations
+    :param num_cities: Number of city node (can hold trainstations)
+    :param num_intersections: Number of intersection that city nodes can connect to
+    :param num_trainstations: Total number of trainstations in env
+    :param min_node_dist: Minimal distance between nodes
+    :param node_radius: Proximity of trainstations to center of city node
+    :param num_neighb: Number of neighbouring nodes each node connects to
+    :param realistic_mode: True -> NOdes evenly distirbuted in env, False-> Random distribution of nodes
+    :param enhance_intersection: True -> Extra rail elements added at intersections
+    :param seed: Random Seed
+    :return:
+        -------
+    numpy.ndarray of type numpy.uint16
+        The matrix with the correct 16-bit bitmaps for each cell.
+    """
+    def generator(width, height, num_agents, num_resets=0):
+        if num_agents > num_trainstations:
+            num_agents = num_trainstations
+            warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents")
+        rail_trans = RailEnvTransitions()
+        grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
+        rail_array = grid_map.grid
+        rail_array.fill(0)
+        np.random.seed(seed + num_resets)
+        # Generate a set of nodes for the sparse network
+        # Try to connect cities to nodes first
+        node_positions = []
+        city_positions = []
+        intersection_positions = []
+        # Evenly distribute cities and intersections
+        if realistic_mode:
+            tot_num_node = num_intersections + num_cities
+            nodes_ratio = height / width
+            nodes_per_row = int(np.ceil(np.sqrt(tot_num_node * nodes_ratio)))
+            nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row))
+            x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int)
+            y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int)
+        for node_idx in range(num_cities + num_intersections):
+            to_close = True
+            tries = 0
+            if not realistic_mode:
+                while to_close:
+                    x_tmp = node_radius + np.random.randint(height - node_radius)
+                    y_tmp = node_radius + np.random.randint(width - node_radius)
+                    to_close = False
+                    # Check distance to cities
+                    for node_pos in city_positions:
+                        if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
+                            to_close = True
+                    # Check distance to intersections
+                    for node_pos in intersection_positions:
+                        if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
+                            to_close = True
+                    if not to_close:
+                        node_positions.append((x_tmp, y_tmp))
+                        if node_idx < num_cities:
+                            city_positions.append((x_tmp, y_tmp))
+                        else:
+                            intersection_positions.append((x_tmp, y_tmp))
+                    tries += 1
+                    if tries > 100:
+                        warnings.warn("Could not set nodes, please change initial parameters!!!!")
+                        break
+            else:
+                x_tmp = x_positions[node_idx % nodes_per_row]
+                y_tmp = y_positions[node_idx // nodes_per_row]
+                if len(city_positions) < num_cities and (node_idx % (tot_num_node // num_cities)) == 0:
+                    city_positions.append((x_tmp, y_tmp))
+                else:
+                    intersection_positions.append((x_tmp, y_tmp))
+        node_positions = city_positions + intersection_positions
+        # Chose node connection
+        # Set up list of available nodes to connect to
+        available_nodes_full = np.arange(num_cities + num_intersections)
+        available_cities = np.arange(num_cities)
+        available_intersections = np.arange(num_cities, num_cities + num_intersections)
+        # Start at some node
+        current_node = np.random.randint(len(available_nodes_full))
+        node_stack = [current_node]
+        allowed_connections = num_neighb
+        first_node = True
+        while len(node_stack) > 0:
+            current_node = node_stack[0]
+            delete_idx = np.where(available_nodes_full == current_node)
+            available_nodes_full = np.delete(available_nodes_full, delete_idx, 0)
+            # Priority city to intersection connections
+            if current_node < num_cities and len(available_intersections) > 0:
+                available_nodes = available_intersections
+                delete_idx = np.where(available_cities == current_node)
+                available_cities = np.delete(available_cities, delete_idx, 0)
+            # Priority intersection to city connections
+            elif current_node >= num_cities and len(available_cities) > 0:
+                available_nodes = available_cities
+                delete_idx = np.where(available_intersections == current_node)
+                available_intersections = np.delete(available_intersections, delete_idx, 0)
+            # If no options possible connect to whatever node is still available
+            else:
+                available_nodes = available_nodes_full
+            # Sort available neighbors according to their distance.
+            node_dist = []
+            for av_node in available_nodes:
+                node_dist.append(distance_on_rail(node_positions[current_node], node_positions[av_node]))
+            available_nodes = available_nodes[np.argsort(node_dist)]
+            # Set number of neighboring nodes
+            if len(available_nodes) >= allowed_connections:
+                connected_neighb_idx = available_nodes[:allowed_connections]
+            else:
+                connected_neighb_idx = available_nodes
+            # Less connections for subsequent nodes
+            if first_node:
+                allowed_connections -= 1
+                first_node = False
+            # Connect to the neighboring nodes
+            for neighb in connected_neighb_idx:
+                if neighb not in node_stack:
+                    node_stack.append(neighb)
+                connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb])
+            node_stack.pop(0)
+        # Place train stations close to the node
+        # We currently place them uniformly distributed among all cities
+        if num_cities > 1:
+            train_stations = [[] for i in range(num_cities)]
+            built_num_trainstation = 0
+            spot_found = True
+            for station in range(num_trainstations):
+                trainstation_node = int(station / num_trainstations * num_cities)
+                station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius),
+                                    0,
+                                    height - 1)
+                station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius),
+                                    0,
+                                    width - 1)
+                tries = 0
+                while (station_x, station_y) in train_stations or (station_x, station_y) == node_positions[
+                    trainstation_node] or rail_array[(station_x, station_y)] != 0:
+                    station_x = np.clip(
+                        node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius),
+                        0,
+                        height - 1)
+                    station_y = np.clip(
+                        node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius),
+                        0,
+                        width - 1)
+                    tries += 1
+                    if tries > 100:
+                        warnings.warn("Could not set trainstations, please change initial parameters!!!!")
+                        spot_found = False
+                        break
+                if spot_found:
+                    train_stations[trainstation_node].append((station_x, station_y))
+                # Connect train station to the correct node
+                connection = connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node],
+                                                (station_x, station_y))
+                # Check if connection was made
+                if len(connection) == 0:
+                    train_stations[trainstation_node].pop(-1)
+                else:
+                    built_num_trainstation += 1
+        # Adjust the number of agents if you could not build enough trainstations
+        if num_agents > built_num_trainstation:
+            num_agents = built_num_trainstation
+            warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents")
+        # Place passing lanes at intersections
+        # We currently place them uniformly distributed among all cities
+        if enhance_intersection:
+            for intersection in range(num_intersections):
+                intersect_x_1 = np.clip(intersection_positions[intersection][0] + np.random.randint(1, 3),
+                                        1,
+                                        height - 2)
+                intersect_y_1 = np.clip(intersection_positions[intersection][1] + np.random.randint(-3, 3),
+                                        2,
+                                        width - 2)
+                intersect_x_2 = np.clip(
+                    intersection_positions[intersection][0] + np.random.randint(-3, -1),
+                    1,
+                    height - 2)
+                intersect_y_2 = np.clip(
+                    intersection_positions[intersection][1] + np.random.randint(-3, 3),
+                    1,
+                    width - 2)
+                # Connect the two passing-lane points to each other and to the intersection
+                connect_nodes(rail_trans, rail_array, (intersect_x_1, intersect_y_1),
+                              (intersect_x_2, intersect_y_2))
+                connect_nodes(rail_trans, rail_array, intersection_positions[intersection],
+                              (intersect_x_1, intersect_y_1))
+                connect_nodes(rail_trans, rail_array, intersection_positions[intersection],
+                              (intersect_x_2, intersect_y_2))
+                grid_map.fix_transitions((intersect_x_1, intersect_y_1))
+                grid_map.fix_transitions((intersect_x_2, intersect_y_2))
+        # Fix all nodes with illegal transition maps
+        for current_node in node_positions:
+            grid_map.fix_transitions(current_node)
+        # Generate start and target node directory for all agents.
+        # Assure that start and target are not in the same node
+        agent_start_targets_nodes = []
+        # Slot availability in node
+        node_available_start = []
+        node_available_target = []
+        for node_idx in range(num_cities):
+            node_available_start.append(len(train_stations[node_idx]))
+            node_available_target.append(len(train_stations[node_idx]))
+        # Assign agents to slots
+        for agent_idx in range(num_agents):
+            avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0]
+            avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0]
+            start_node = np.random.choice(avail_start_nodes)
+            target_node = np.random.choice(avail_target_nodes)
+            tries = 0
+            found_agent_pair = True
+            while target_node == start_node:
+                target_node = np.random.choice(avail_target_nodes)
+                tries += 1
+                # Test again with new start node if no pair is found (This code needs to be improved)
+                if (tries + 1) % 10 == 0:
+                    start_node = np.random.choice(avail_start_nodes)
+                if tries > 100:
+                    warnings.warn("Could not set trainstations, removing agent!")
+                    found_agent_pair = False
+                    break
+            if found_agent_pair:
+                node_available_start[start_node] -= 1
+                node_available_target[target_node] -= 1
+                agent_start_targets_nodes.append((start_node, target_node))
+            else:
+                num_agents -= 1
+        # Place agents and targets within available train stations
+        agents_position = []
+        agents_target = []
+        agents_direction = []
+        for agent_idx in range(num_agents):
+            # Set target for agent
+            current_target_node = agent_start_targets_nodes[agent_idx][1]
+            target_station_idx = np.random.randint(len(train_stations[current_target_node]))
+            target = train_stations[current_target_node][target_station_idx]
+            tries = 0
+            while (target[0], target[1]) in agents_target:
+                target_station_idx = np.random.randint(len(train_stations[current_target_node]))
+                target = train_stations[current_target_node][target_station_idx]
+                tries += 1
+                if tries > 100:
+                    warnings.warn("Could not set target position, removing an agent")
+                    break
+            agents_target.append((target[0], target[1]))
+            # Set start for agent
+            current_start_node = agent_start_targets_nodes[agent_idx][0]
+            start_station_idx = np.random.randint(len(train_stations[current_start_node]))
+            start = train_stations[current_start_node][start_station_idx]
+            tries = 0
+            while (start[0], start[1]) in agents_position:
+                tries += 1
+                if tries > 100:
+                    warnings.warn("Could not set start position, please change initial parameters!!!!")
+                    break
+                start_station_idx = np.random.randint(len(train_stations[current_start_node]))
+                start = train_stations[current_start_node][start_station_idx]
+            agents_position.append((start[0], start[1]))
+            # Orient the agent correctly
+            for orientation in range(4):
+                transitions = grid_map.get_transitions(start[0], start[1], orientation)
+                if any(transitions) > 0:
+                    agents_direction.append(orientation)
+                    continue
+        return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+    return generator
@@ -41,7 +41,7 @@ class PILGL(GraphicsLayer):
-    def __init__(self, width, height, jupyter=False):
+    def __init__(self, width, height, jupyter=False, screen_width=800, screen_height=600):
         self.yxBase = (0, 0)
         self.linewidth = 4
         self.n_agent_colors = 1  # overridden in loadAgent
@@ -52,13 +52,13 @@ class PILGL(GraphicsLayer):
         @@ -52,13 +52,13 @@ class PILGL(GraphicsLayer):
         if jupyter is False:
-            # NOTE: Currently removed the dependency on 
-            #       screeninfo. We have to find an alternate 
+            # NOTE: Currently removed the dependency on
+            #       screeninfo. We have to find an alternate
             #       way to compute the screen width and height
-            #       In the meantime, we are harcoding the 800x600 
+            #       In the meantime, we are hardcoding the 800x600
             #       assumption
-            self.screen_width = 800
-            self.screen_height = 600
+            self.screen_width = screen_width
+            self.screen_height = screen_height
             w = (self.screen_width - self.width - 10) / (self.width + 1 + self.linewidth)
             h = (self.screen_height - self.height - 10) / (self.height + 1 + self.linewidth)
             self.nPixCell = int(max(1, np.ceil(min(w, h))))
@@ -116,7 +116,7 @@ class PILGL(GraphicsLayer):
                     for rc in dTargets:
                         r = rc[1]
                         c = rc[0]
-                        d = int(np.floor(np.sqrt((x - r) ** 2 + (y - c) ** 2)))
+                        d = int(np.floor(np.sqrt((x - r) ** 2 + (y - c) ** 2)) / 0.5)
                         distance = min(d, distance)
                     self.background_grid[x][y] = distance
@@ -271,9 +271,9 @@ class PILGL(GraphicsLayer):
@@ -271,9 +271,9 @@ class PILGL(GraphicsLayer):
+    def __init__(self, width, height, jupyter=False, screen_width=800, screen_height=600):
         oSuper = super()
-        oSuper.__init__(width, height, jupyter)
+        oSuper.__init__(width, height, jupyter, screen_width, screen_height)
         self.lwAgents = []
         self.agents_prev = []
@@ -452,7 +452,7 @@ class PILSVG(PILGL):
         @@ -452,7 +452,7 @@ class PILSVG(PILGL):
-            # Translate the ascii transition description in the format  "NE WS" to the 
+            # Translate the ascii transition description in the format  "NE WS" to the
             # binary list of transitions as per RailEnv - NESW (in) x NESW (out)
             transition_16_bit = ["0"] * 16
             for sTran in transition.split(" "):
@@ -40,8 +40,9 @@ class RenderTool(object):
     arc = array([np.cos(theta), np.sin(theta)]).T  # from [1,0] to [0,1]
     def __init__(self, env, gl="PILSVG", jupyter=False,
-            agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
-            show_debug=True):
+                 agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
+                 show_debug=False, screen_width=800, screen_height=600):
         self.env = env
         self.frame_nr = 0
         self.start_time = time.time()
@@ -50,12 +51,12 @@ class RenderTool(object):
         self.agent_render_variant = agent_render_variant
         if gl == "PIL":
-            self.gl = PILGL(env.width, env.height, jupyter)
+            self.gl = PILGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
         elif gl == "PILSVG":
-            self.gl = PILSVG(env.width, env.height, jupyter)
+            self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
             print("[", gl, "] not found, switch to PILSVG")
-            self.gl = PILSVG(env.width, env.height, jupyter)
+            self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
         self.new_rail = True
         self.show_debug = show_debug
@@ -554,7 +555,7 @@ class RenderTool(object):
                 if self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:
                     self.gl.set_cell_occupied(agent_idx, *(agent.position))
                 self.gl.set_agent_at(agent_idx, *position, old_direction, direction,
-                    selected_agent == agent_idx, show_debug=self.show_debug)
+                                     selected_agent == agent_idx, show_debug=self.show_debug)
                 position = agent.position
                 direction = agent.direction
@@ -2,11 +2,87 @@ from typing import Tuple
 import numpy as np
-from flatland.core.grid.grid4 import Grid4Transitions
+from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
+    # Build a small, fixed rail layout for tests and examples.
+    #
+    # Returns a tuple (rail, rail_map):
+    #   rail     -- GridTransitionMap sized from rail_map.shape, whose `grid`
+    #               attribute is set to rail_map (the same array object)
+    #   rail_map -- uint16 numpy array holding one transition code per cell
+    #
+    # We instantiate a very simple rail network on a 7x10 grid:
+    # Note that cells have invalid RailEnvTransitions!
+    # NOTE(review): the validity claim above cannot be checked from this
+    #               function alone -- confirm against RailEnvTransitions.
+    #        |
+    #        |
+    #        |
+    # _ _ _ _\ _ _  _  _ _ _
+    #                /
+    #                |
+    #                |
+    #                |
+    transitions = RailEnvTransitions()
+    # Base cell codes are looked up by index in the transition list; the
+    # other orientations are derived by rotating a base code.
+    # NOTE(review): the meaning of each index (0, 1, 2, 7, 10) is taken from
+    #               the variable names used here -- verify against
+    #               RailEnvTransitions.transition_list.
+    cells = transitions.transition_list
+    empty = cells[0]
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+    simple_switch_north_left = cells[2]
+    simple_switch_north_right = cells[10]
+    simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
+    simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270)
+    # Grid rows, top to bottom (10 columns each):
+    #   row 0     : dead end in column 3
+    #   rows 1-2  : vertical straight in column 3
+    #   row 3     : horizontal line with dead ends in columns 0 and 9 and
+    #               switches in columns 3 and 6
+    #   rows 4-5  : vertical straight in column 6
+    #   row 6     : dead end in column 6
+    rail_map = np.array(
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
+        [[dead_end_from_east] + [horizontal_straight] * 2 +
+         [simple_switch_east_west_north] +
+         [horizontal_straight] * 2 + [simple_switch_east_west_south] +
+         [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+    # width is the number of columns, height the number of rows
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    return rail, rail_map
+def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
+    # Build a variant of the small fixed rail layout: the grid has the same
+    # shape as in make_simple_rail, but the switch in row 3, column 6 is
+    # built from cells[10] rotated by 90 instead of cells[2] rotated by 270.
+    #
+    # Returns a tuple (rail, rail_map):
+    #   rail     -- GridTransitionMap sized from rail_map.shape, whose `grid`
+    #               attribute is set to rail_map (the same array object)
+    #   rail_map -- uint16 numpy array holding one transition code per cell
+    #
+    # We instantiate a very simple rail network on a 7x10 grid:
+    #        |
+    #        |
+    #        |
+    # _ _ _ _\ _ _  _  _ _ _
+    #               \
+    #                |
+    #                |
+    #                |
+    transitions = RailEnvTransitions()
+    # Base cell codes are looked up by index in the transition list; the
+    # other orientations are derived by rotating a base code.
+    # NOTE(review): the meaning of each index (0, 1, 7, 10) is taken from
+    #               the variable names used here -- verify against
+    #               RailEnvTransitions.transition_list.
+    cells = transitions.transition_list
+    empty = cells[0]
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+    simple_switch_north_right = cells[10]
+    simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
+    simple_switch_west_east_south = transitions.rotate_transition(simple_switch_north_right, 90)
+    # Grid rows, top to bottom (10 columns each):
+    #   row 0     : dead end in column 3
+    #   rows 1-2  : vertical straight in column 3
+    #   row 3     : horizontal line with dead ends in columns 0 and 9 and
+    #               switches in columns 3 and 6
+    #   rows 4-5  : vertical straight in column 6
+    #   row 6     : dead end in column 6
+    rail_map = np.array(
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
+        [[dead_end_from_east] + [horizontal_straight] * 2 +
+         [simple_switch_east_west_north] +
+         [horizontal_straight] * 2 + [simple_switch_west_east_south] +
+         [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+    # width is the number of columns, height the number of rows
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    return rail, rail_map
+def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]:
     # We instantiate a very simple rail network on a 7x10 grid:
     #        |
     #        |
@@ -16,15 +92,9 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
     #                |
     #                |
     #                |
-    cells = [int('0000000000000000', 2),  # empty cell - Case 0
-             int('1000000000100000', 2),  # Case 1 - straight
-             int('1001001000100000', 2),  # Case 2 - simple switch
-             int('1000010000100001', 2),  # Case 3 - diamond drossing
-             int('1001011000100001', 2),  # Case 4 - single slip switch
-             int('1100110000110011', 2),  # Case 5 - double slip switch
-             int('0101001000000010', 2),  # Case 6 - symmetrical switch
-             int('0010000000000000', 2)]  # Case 7 - dead end
-    transitions = Grid4Transitions([])
+    transitions = RailEnvTransitions()
+    cells = transitions.transition_list
     empty = cells[0]
     dead_end_from_south = cells[7]
     dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
 import numpy as np
-from flatland.core.grid.grid4 import Grid4Transitions
+from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -12,15 +12,8 @@ from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail
 def test_walker():
     # _ _ _
-    cells = [int('0000000000000000', 2),  # empty cell - Case 0
-             int('1000000000100000', 2),  # Case 1 - straight
-             int('1001001000100000', 2),  # Case 2 - simple switch
-             int('1000010000100001', 2),  # Case 3 - diamond drossing
-             int('1001011000100001', 2),  # Case 4 - single slip switch
-             int('1100110000110011', 2),  # Case 5 - double slip switch
-             int('0101001000000010', 2),  # Case 6 - symmetrical switch
-             int('0010000000000000', 2)]  # Case 7 - dead end
-    transitions = Grid4Transitions([])
+    transitions = RailEnvTransitions()
+    cells = transitions.transition_list
     dead_end_from_south = cells[7]
     dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
     dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
 from flatland.envs.rail_generators import rail_from_grid_transition_map
 from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail
 from flatland.utils.rendertools import RenderTool
-from flatland.utils.simple_rail import make_simple_rail
+from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
 """Test predictions for `flatland` package."""
 def test_dummy_predictor(rendering=False):
-    rail, rail_map = make_simple_rail()
+    rail, rail_map = make_simple_rail2()
     env = RailEnv(width=rail_map.shape[1],
@@ -91,7 +91,7 @@ def test_dummy_predictor(rendering=False):
     expected_actions = np.array([[0.],
-                                 [1.],
+                                 [2.],
@@ -229,7 +229,7 @@ def test_shortest_path_predictor(rendering=False):
 def test_shortest_path_predictor_conflicts(rendering=False):
-    rail, rail_map = make_simple_rail()
+    rail, rail_map = make_invalid_simple_rail()
     env = RailEnv(width=rail_map.shape[1],
 # -*- coding: utf-8 -*-
 import numpy as np
-from flatland.core.grid.grid4 import Grid4Transitions
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgent
@@ -51,15 +50,6 @@ def test_save_load():
 def test_rail_environment_single_agent():
-    cells = [int('0000000000000000', 2),  # empty cell - Case 0
-             int('1000000000100000', 2),  # Case 1 - straight
-             int('1001001000100000', 2),  # Case 2 - simple switch
-             int('1000010000100001', 2),  # Case 3 - diamond drossing
-             int('1001011000100001', 2),  # Case 4 - single slip switch
-             int('1100110000110011', 2),  # Case 5 - double slip switch
-             int('0101001000000010', 2),  # Case 6 - symmetrical switch
-             int('0010000000000000', 2)]  # Case 7 - dead end
     # We instantiate the following map on a 3x3 grid
     #  _  _
     # / \/ \
@@ -67,6 +57,7 @@ def test_rail_environment_single_agent():
     # \_/\_/
     transitions = RailEnvTransitions()
+    cells = transitions.transition_list
     vertical_line = cells[1]
     south_symmetrical_switch = cells[6]
     north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
@@ -142,7 +133,7 @@ test_rail_environment_single_agent()
 def test_dead_end():
-    transitions = Grid4Transitions([])
+    transitions = RailEnvTransitions()
     straight_vertical = int('1000000000100000', 2)  # Case 1 - straight
     straight_horizontal = transitions.rotate_transition(straight_vertical,
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
+from flatland.envs.observations import GlobalObsForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_rail_generator_agents_placer
+from flatland.utils.rendertools import RenderTool
+def test_sparse_rail_generator():
+    env = RailEnv(width=50,
+                  height=50,
+                  rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
+                                                       num_intersections=10,  # Number of interesections in map
+                                                       num_trainstations=50,  # Number of possible start/targets on map
+                                                       min_node_dist=6,  # Minimal distance of nodes
+                                                       node_radius=3,  # Proximity of stations to city center
+                                                       num_neighb=3,  # Number of connections to other cities
+                                                       seed=5,  # Random seed
+                                                       realistic_mode=False  # Ordered distribution of nodes
+                                                       ),
+                  agent_generator=sparse_rail_generator_agents_placer(),
+                  number_of_agents=10,
+                  obs_builder_object=GlobalObsForRailEnv())
+    # reset to initialize agents_static
+    env_renderer = RenderTool(env, gl="PILSVG", )
+    env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
+    env_renderer.gl.save_image("./sparse_generator_false.png")