diff --git a/docs/flatland.rst b/docs/flatland.rst
index 88e8ec93fd4f6c89c1f6e20c55defbddaa9b28fa..e09087a49b6df3572ac38b77e41ca739bcea8150 100644
--- a/docs/flatland.rst
+++ b/docs/flatland.rst
@@ -6,10 +6,10 @@ Subpackages
 
 .. toctree::
 
-    flatland.core
-    flatland.envs
-    flatland.evaluators
-    flatland.utils
+   flatland.core
+   flatland.envs
+   flatland.evaluators
+   flatland.utils
 
 Submodules
 ----------
@@ -18,15 +18,15 @@ flatland.cli module
 -------------------
 
 .. automodule:: flatland.cli
-    :members:
-    :undoc-members:
-    :show-inheritance:
+   :members:
+   :undoc-members:
+   :show-inheritance:
 
 
 Module contents
 ---------------
 
 .. automodule:: flatland
-    :members:
-    :undoc-members:
-    :show-inheritance:
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/examples/complex_rail_benchmark.py b/examples/complex_rail_benchmark.py
index 44e4b534c2f2dfa63cab385d009c9afd92285f48..49e550b195ffb15e6554413069369378e80e5f82 100644
--- a/examples/complex_rail_benchmark.py
+++ b/examples/complex_rail_benchmark.py
@@ -3,8 +3,9 @@ import random
 
 import numpy as np
 
-from flatland.envs.generators import complex_rail_generator
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
 
 
 def run_benchmark():
@@ -15,6 +16,7 @@ def run_benchmark():
     # Example generate a random rail
     env = RailEnv(width=15, height=15,
                   rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
+                  schedule_generator=complex_schedule_generator(),
                   number_of_agents=5)
 
     n_trials = 20
diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py
index 723bb1102092c7d48bd938bbd60d0c5213ffecf6..8b1de6aa4e303469d30983d30333fbfda89c1d1e 100644
--- a/examples/custom_observation_example.py
+++ b/examples/custom_observation_example.py
@@ -5,10 +5,11 @@ import numpy as np
 
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid_utils import coordinate_to_position
-from flatland.envs.generators import random_rail_generator, complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import random_rail_generator, complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
 from flatland.utils.rendertools import RenderTool
 
 random.seed(100)
@@ -20,6 +21,7 @@ class SimpleObs(ObservationBuilder):
     Simplest observation builder. The object returns observation vectors with 5 identical components,
     all equal to the ID of the respective agent.
     """
+
     def __init__(self):
         self.observation_space = [5]
 
@@ -53,6 +55,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
     E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
     will be [1, 0, 0].
     """
+
     def __init__(self):
         super().__init__(max_depth=0)
         self.observation_space = [3]
@@ -90,6 +93,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
 env = RailEnv(width=7,
               height=7,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0),
+              schedule_generator=complex_schedule_generator(),
               number_of_agents=1,
               obs_builder_object=SingleAgentNavigationObs())
 
@@ -97,8 +101,8 @@ obs = env.reset()
 env_renderer = RenderTool(env, gl="PILSVG")
 env_renderer.render_env(show=True, frames=True, show_observations=True)
 for step in range(100):
-    action = np.argmax(obs[0])+1
-    obs, all_rewards, done, _ = env.step({0:action})
+    action = np.argmax(obs[0]) + 1
+    obs, all_rewards, done, _ = env.step({0: action})
     print("Rewards: ", all_rewards, "  [done=", done, "]")
     env_renderer.render_env(show=True, frames=True, show_observations=True)
     time.sleep(0.1)
@@ -200,6 +204,7 @@ CustomObsBuilder = ObservePredictions(CustomPredictor)
 env = RailEnv(width=10,
               height=10,
               rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+              schedule_generator=complex_schedule_generator(),
               number_of_agents=3,
               obs_builder_object=CustomObsBuilder)
 
diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py
index 515d6c1b0469b7fbd9bad8cd82a40db7766f6219..6162b918734eb311752675e75e203d90e5558c1c 100644
--- a/examples/custom_railmap_example.py
+++ b/examples/custom_railmap_example.py
@@ -1,30 +1,41 @@
 import random
+from typing import Any
 
 import numpy as np
 
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct
+from flatland.envs.schedule_generators import ScheduleGenerator, ScheduleGeneratorProduct
 from flatland.utils.rendertools import RenderTool
 
 random.seed(100)
 np.random.seed(100)
 
 
-def custom_rail_generator():
-    def generator(width, height, num_agents=0, num_resets=0):
+def custom_rail_generator() -> RailGenerator:
+    def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct:
         rail_trans = RailEnvTransitions()
         grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
         rail_array = grid_map.grid
         rail_array.fill(0)
         new_tran = rail_trans.set_transition(1, 1, 1, 1)
         print(new_tran)
+        rail_array[0, 0] = new_tran
+        rail_array[0, 1] = new_tran
+        return grid_map, None
+
+    return generator
+
+
+def custom_schedule_generator() -> ScheduleGenerator:
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct:
         agents_positions = []
         agents_direction = []
         agents_target = []
-        rail_array[0, 0] = new_tran
-        rail_array[0, 1] = new_tran
-        return grid_map, agents_positions, agents_direction, agents_target
+        speeds = []
+        return agents_positions, agents_direction, agents_target, speeds
 
     return generator
 
@@ -32,6 +43,7 @@ def custom_rail_generator():
 env = RailEnv(width=6,
               height=4,
               rail_generator=custom_rail_generator(),
+              schedule_generator=custom_schedule_generator(),
               number_of_agents=1)
 
 env.reset()
diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py
index 2c0f814576caef84471d20c91dd92d23d4db02ac..50ea74b84ac9851e88e48bcd32b914e69bc7dd34 100644
--- a/examples/debugging_example_DELETE.py
+++ b/examples/debugging_example_DELETE.py
@@ -3,14 +3,16 @@ import time
 
 import numpy as np
 
-from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
 from flatland.utils.rendertools import RenderTool
 
 random.seed(1)
 np.random.seed(1)
 
+
 class SingleAgentNavigationObs(TreeObsForRailEnv):
     """
     We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
@@ -21,6 +23,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
     E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
     will be [1, 0, 0].
     """
+
     def __init__(self):
         super().__init__(max_depth=0)
         self.observation_space = [3]
@@ -58,6 +61,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
 env = RailEnv(width=14,
               height=14,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0),
+              schedule_generator=complex_schedule_generator(),
               number_of_agents=2,
               obs_builder_object=SingleAgentNavigationObs())
 
@@ -67,11 +71,11 @@ env_renderer.render_env(show=True, frames=True, show_observations=False)
 for step in range(100):
     actions = {}
     for i in range(len(obs)):
-        actions[i] = np.argmax(obs[i])+1
+        actions[i] = np.argmax(obs[i]) + 1
 
-    if step%5 == 0:
+    if step % 5 == 0:
         print("Agent halts")
-        actions[0] = 4 # Halt
+        actions[0] = 4  # Halt
 
     obs, all_rewards, done, _ = env.step(actions)
     if env.agents[0].malfunction_data['malfunction'] > 0:
@@ -82,4 +86,3 @@ for step in range(100):
     if done["__all__"]:
         break
 env_renderer.close_window()
-
diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..71a185c765bcab831e7b104124a164bcf2398b14
--- /dev/null
+++ b/examples/flatland_2_0_example.py
@@ -0,0 +1,120 @@
+import numpy as np
+from flatland.envs.rail_generators import sparse_rail_generator
+
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.schedule_generators import sparse_schedule_generator
+from flatland.utils.rendertools import RenderTool
+
+np.random.seed(1)
+
+# Use the new sparse_rail_generator to generate feasible network configurations with corresponding tasks
+# Training on simple small tasks is the best way to get familiar with the environment
+
+# Use a the malfunction generator to break agents from time to time
+stochastic_data = {'prop_malfunction': 0.5,  # Percentage of defective agents
+                   'malfunction_rate': 30,  # Rate of malfunction occurence
+                   'min_duration': 3,  # Minimal duration of malfunction
+                   'max_duration': 10  # Max duration of malfunction
+                   }
+
+TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
+env = RailEnv(width=20,
+              height=20,
+              rail_generator=sparse_rail_generator(num_cities=2,  # Number of cities in map (where train stations are)
+                                                   num_intersections=1,  # Number of intersections (no start / target)
+                                                   num_trainstations=15,  # Number of possible start/targets on map
+                                                   min_node_dist=3,  # Minimal distance of nodes
+                                                   node_radius=3,  # Proximity of stations to city center
+                                                   num_neighb=2,  # Number of connections to other cities/intersections
+                                                   seed=15,  # Random seed
+                                                   realistic_mode=True,
+                                                   enhance_intersection=True
+                                                   ),
+              schedule_generator=sparse_schedule_generator(),
+              number_of_agents=5,
+              stochastic_data=stochastic_data,  # Malfunction data generator
+              obs_builder_object=TreeObservation)
+
+env_renderer = RenderTool(env, gl="PILSVG", )
+
+
+# Import your own Agent or use RLlib to train agents on Flatland
+# As an example we use a random agent instead
+class RandomAgent:
+
+    def __init__(self, state_size, action_size):
+        self.state_size = state_size
+        self.action_size = action_size
+
+    def act(self, state):
+        """
+        :param state: input is the observation of the agent
+        :return: returns an action
+        """
+        return np.random.choice(np.arange(self.action_size))
+
+    def step(self, memories):
+        """
+        Step function to improve agent by adjusting policy given the observations
+
+        :param memories: SARS Tuple to be
+        :return:
+        """
+        return
+
+    def save(self, filename):
+        # Store the current policy
+        return
+
+    def load(self, filename):
+        # Load a policy
+        return
+
+
+# Initialize the agent with the parameters corresponding to the environment and observation_builder
+# Set action space to 4 to remove stop action
+agent = RandomAgent(218, 4)
+
+# Empty dictionary for all agent action
+action_dict = dict()
+
+print("Start episode...")
+# Reset environment and get initial observations for all agents
+obs = env.reset()
+# Update/Set agent's speed
+for idx in range(env.get_num_agents()):
+    speed = 1.0 / ((idx % 5) + 1.0)
+    env.agents[idx].speed_data["speed"] = speed
+
+# Reset the rendering sytem
+env_renderer.reset()
+
+# Here you can also further enhance the provided observation by means of normalization
+# See training navigation example in the baseline repository
+
+score = 0
+# Run episode
+frame_step = 0
+for step in range(500):
+    # Chose an action for each agent in the environment
+    for a in range(env.get_num_agents()):
+        action = agent.act(obs[a])
+        action_dict.update({a: action})
+
+    # Environment step which returns the observations for all agents, their corresponding
+    # reward and whether their are done
+    next_obs, all_rewards, done, _ = env.step(action_dict)
+    env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
+    frame_step += 1
+    # Update replay buffer and train agent
+    for a in range(env.get_num_agents()):
+        agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
+        score += all_rewards[a]
+
+    obs = next_obs.copy()
+    if done['__all__']:
+        break
+
+print('Episode: Steps {}\t Score = {}'.format(step, score))
diff --git a/examples/simple_example_1.py b/examples/simple_example_1.py
index 7956c34fd4a5b94859a4b64441450afe2114133c..fbadbd657c36fa1dadf0bca65cff3e9cccd269ea 100644
--- a/examples/simple_example_1.py
+++ b/examples/simple_example_1.py
@@ -1,5 +1,5 @@
-from flatland.envs.generators import rail_from_manual_specifications_generator
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import rail_from_manual_specifications_generator
 from flatland.utils.rendertools import RenderTool
 
 # Example generate a rail given a manual specification,
diff --git a/examples/simple_example_2.py b/examples/simple_example_2.py
index 994c7deda1569b77d4adac8a17fa9ebe14b27ef6..6db9ba5abbd0999ef3896e733516ed6b3e498bae 100644
--- a/examples/simple_example_2.py
+++ b/examples/simple_example_2.py
@@ -2,8 +2,8 @@ import random
 
 import numpy as np
 
-from flatland.envs.generators import random_rail_generator
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import random_rail_generator
 from flatland.utils.rendertools import RenderTool
 
 random.seed(100)
diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py
index 5aa03d8f95a7079b708baea1e2ddce27e9a46554..6df6d4af3076b3d9659aadbd55296b667dc7d6db 100644
--- a/examples/simple_example_3.py
+++ b/examples/simple_example_3.py
@@ -2,9 +2,10 @@ import random
 
 import numpy as np
 
-from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
 from flatland.utils.rendertools import RenderTool
 
 random.seed(1)
@@ -13,6 +14,7 @@ np.random.seed(1)
 env = RailEnv(width=7,
               height=7,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+              schedule_generator=complex_schedule_generator(),
               number_of_agents=2,
               obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
diff --git a/examples/training_example.py b/examples/training_example.py
index d125be1587a56025ba1cd3f78b28ba3976f01fbf..df93479f5a5ee05abfcb1a98b07ef052bffc2bd4 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -1,9 +1,10 @@
 import numpy as np
 
-from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
 from flatland.utils.rendertools import RenderTool
 
 np.random.seed(1)
@@ -16,11 +17,13 @@ LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
 env = RailEnv(width=20,
               height=20,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+              schedule_generator=complex_schedule_generator(),
               obs_builder_object=TreeObservation,
               number_of_agents=3)
 
 env_renderer = RenderTool(env, gl="PILSVG", )
 
+
 # Import your own Agent or use RLlib to train agents on Flatland
 # As an example we use a random agent here
 
diff --git a/flatland/cli.py b/flatland/cli.py
index 32e8d9dc786b0412795694fc985c90aa55fc2e91..47c450dba803fac17bc13979663ef04e4c0db899 100644
--- a/flatland/cli.py
+++ b/flatland/cli.py
@@ -2,29 +2,33 @@
 
 """Console script for flatland."""
 import sys
+import time
+
 import click
 import numpy as np
-import time
-from flatland.envs.generators import complex_rail_generator
+import redis
+
 from flatland.envs.rail_env import RailEnv
-from flatland.utils.rendertools import RenderTool
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
 from flatland.evaluators.service import FlatlandRemoteEvaluationService
-import redis
+from flatland.utils.rendertools import RenderTool
 
 
 @click.command()
 def demo(args=None):
     """Demo script to check installation"""
     env = RailEnv(
-            width=15,
-            height=15,
-            rail_generator=complex_rail_generator(
-                                    nr_start_goal=10,
-                                    nr_extra=1,
-                                    min_dist=8,
-                                    max_dist=99999),
-            number_of_agents=5)
-    
+        width=15,
+        height=15,
+        rail_generator=complex_rail_generator(
+            nr_start_goal=10,
+            nr_extra=1,
+            min_dist=8,
+            max_dist=99999),
+        schedule_generator=complex_schedule_generator(),
+        number_of_agents=5)
+
     env._max_episode_steps = int(15 * (env.width + env.height))
     env_renderer = RenderTool(env)
 
@@ -52,12 +56,12 @@ def demo(args=None):
 
 
 @click.command()
-@click.option('--tests', 
+@click.option('--tests',
               type=click.Path(exists=True),
               help="Path to folder containing Flatland tests",
               required=True
               )
-@click.option('--service_id', 
+@click.option('--service_id',
               default="FLATLAND_RL_SERVICE_ID",
               help="Evaluation Service ID. This has to match the service id on the client.",
               required=False
@@ -70,14 +74,14 @@ def evaluator(tests, service_id):
         raise Exception(
             "\nRedis server does not seem to be running on your localhost.\n"
             "Please ensure that you have a redis server running on your localhost"
-            )
-    
+        )
+
     grader = FlatlandRemoteEvaluationService(
-                test_env_folder=tests,
-                flatland_rl_service_id=service_id,
-                visualize=False,
-                verbose=False
-                )
+        test_env_folder=tests,
+        flatland_rl_service_id=service_id,
+        visualize=False,
+        verbose=False
+    )
     grader.run()
 
 
diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index 5e0f6cd72e8ca22c80e3576798e5214f8b036558..bb954998688772a7ce69e5228cff3e16d037f2af 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -7,6 +7,7 @@ from importlib_resources import path
 from numpy import array
 
 from flatland.core.grid.grid4 import Grid4Transitions
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transitions import Transitions
 
 
@@ -298,6 +299,76 @@ class GridTransitionMap(TransitionMap):
         self.height = new_height
         self.grid = new_grid
 
+    def is_dead_end(self, rcPos):
+        """
+        Check if the cell is a dead-end.
+
+        Parameters
+        ----------
+        rcPos: Tuple[int,int]
+            tuple(row, column) with grid coordinate
+        Returns
+        -------
+        boolean
+            True if and only if the cell is a dead-end.
+        """
+        nbits = 0
+        tmp = self.get_full_transitions(rcPos[0], rcPos[1])
+        while tmp > 0:
+            nbits += (tmp & 1)
+            tmp = tmp >> 1
+        return nbits == 1
+
+    def is_simple_turn(self, rcPos):
+        """
+        Check if the cell is a left/right simple turn
+
+        Parameters
+        ----------
+            rcPos: Tuple[int,int]
+                tuple(row, column) with grid coordinate
+        Returns
+        -------
+            boolean
+                True if and only if the cell is a left/right simple turn.
+        """
+        tmp = self.get_full_transitions(rcPos[0], rcPos[1])
+
+        def is_simple_turn(trans):
+            all_simple_turns = set()
+            for trans in [int('0100000000000010', 2),  # Case 1b (8)  - simple turn right
+                          int('0001001000000000', 2)  # Case 1c (9)  - simple turn left]:
+                          ]:
+                for _ in range(3):
+                    trans = self.transitions.rotate_transition(trans, rotation=90)
+                    all_simple_turns.add(trans)
+            return trans in all_simple_turns
+
+        return is_simple_turn(tmp)
+
+    def check_path_exists(self, start, direction, end):
+        # print("_path_exists({},{},{}".format(start, direction, end))
+        # BFS - Check if a path exists between the 2 nodes
+
+        visited = set()
+        stack = [(start, direction)]
+        while stack:
+            node = stack.pop()
+            node_position = node[0]
+            node_direction = node[1]
+            if node_position[0] == end[0] and node_position[1] == end[1]:
+                return True
+            if node not in visited:
+                visited.add(node)
+
+                moves = self.get_transitions(node_position[0], node_position[1], node_direction)
+                for move_index in range(4):
+                    if moves[move_index]:
+                        stack.append((get_new_position(node_position, move_index),
+                                      move_index))
+
+        return False
+
     def cell_neighbours_valid(self, rcPos, check_this_cell=False):
         """
         Check validity of cell at rcPos = tuple(row, column)
@@ -350,4 +421,124 @@ class GridTransitionMap(TransitionMap):
 
         return True
 
+    def fix_neighbours(self, rcPos, check_this_cell=False):
+        """
+        Check validity of cell at rcPos = tuple(row, column)
+        Checks that:
+        - surrounding cells have inbound transitions for all the
+            outbound transitions of this cell.
+
+        These are NOT checked - see transition.is_valid:
+        - all transitions have the mirror transitions (N->E <=> W->S)
+        - Reverse transitions (N -> S) only exist for a dead-end
+        - a cell contains either no dead-ends or exactly one
+
+        Returns: True (valid) or False (invalid)
+        """
+        cell_transition = self.grid[tuple(rcPos)]
+
+        if check_this_cell:
+            if not self.transitions.is_valid(cell_transition):
+                return False
+
+        gDir2dRC = self.transitions.gDir2dRC  # [[-1,0] = N, [0,1]=E, etc]
+        grcPos = array(rcPos)
+        grcMax = self.grid.shape
+
+        binTrans = self.get_full_transitions(*rcPos)  # 16bit integer - all trans in/out
+        lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8)  # 2 x uint8
+        g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4)  # 4x4 x uint8 binary(0,1)
+        gDirOut = g2binTrans.any(axis=0)  # outbound directions as boolean array (4)
+        giDirOut = np.argwhere(gDirOut)[:, 0]  # valid outbound directions as array of int
+
+        # loop over available outbound directions (indices) for rcPos
+        for iDirOut in giDirOut:
+            gdRC = gDir2dRC[iDirOut]  # row,col increment
+            gPos2 = grcPos + gdRC  # next cell in that direction
+
+            # Check the adjacent cell is within bounds
+            # if not, then this transition is invalid!
+            if np.any(gPos2 < 0):
+                return False
+            if np.any(gPos2 >= grcMax):
+                return False
+
+            # Get the transitions out of gPos2, using iDirOut as the inbound direction
+            # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
+            t4Trans2 = self.get_transitions(*gPos2, iDirOut)
+            if any(t4Trans2):
+                continue
+            else:
+                self.set_transition((gPos2[0], gPos2[1], iDirOut), mirror(iDirOut), 1)
+                return False
+
+        return True
+
+    def fix_transitions(self, rcPos):
+        """
+        Fixes broken transitions
+        """
+        gDir2dRC = self.transitions.gDir2dRC  # [[-1,0] = N, [0,1]=E, etc]
+        grcPos = array(rcPos)
+        grcMax = self.grid.shape
+
+        # loop over available outbound directions (indices) for rcPos
+        self.set_transitions(rcPos, 0)
+
+        incoming_connections = np.zeros(4)
+        for iDirOut in np.arange(4):
+            gdRC = gDir2dRC[iDirOut]  # row,col increment
+            gPos2 = grcPos + gdRC  # next cell in that direction
+
+            # Check the adjacent cell is within bounds
+            # if not, then ignore it for the count of incoming connections
+            if np.any(gPos2 < 0):
+                continue
+            if np.any(gPos2 >= grcMax):
+                continue
+
+            # Get the transitions out of gPos2, using iDirOut as the inbound direction
+            # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
+            connected = 0
+            for orientation in range(4):
+                connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut))
+            if connected > 0:
+                incoming_connections[iDirOut] = 1
+
+        number_of_incoming = np.sum(incoming_connections)
+        # Only one incoming direction --> Straight line
+        if number_of_incoming == 1:
+            for direction in range(4):
+                if incoming_connections[direction] > 0:
+                    self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1)
+        # Connect all incoming connections
+        if number_of_incoming == 2:
+            connect_directions = np.argwhere(incoming_connections > 0)
+            self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
+            self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1)
+
+        # Find feasible connection fro three entries
+        if number_of_incoming == 3:
+            hole = np.argwhere(incoming_connections < 1)[0][0]
+            connect_directions = [(hole + 1) % 4, (hole + 2) % 4, (hole + 3) % 4]
+            self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
+            self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[2], 1)
+            self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1)
+            self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[2])), connect_directions[0], 1)
+        # Make a cross
+        if number_of_incoming == 4:
+            connect_directions = np.arange(4)
+            self.set_transition((rcPos[0], rcPos[1], connect_directions[0]), connect_directions[0], 1)
+            self.set_transition((rcPos[0], rcPos[1], connect_directions[0]), connect_directions[1], 1)
+            self.set_transition((rcPos[0], rcPos[1], connect_directions[1]), connect_directions[0], 1)
+            self.set_transition((rcPos[0], rcPos[1], connect_directions[1]), connect_directions[1], 1)
+            self.set_transition((rcPos[0], rcPos[1], connect_directions[2]), connect_directions[2], 1)
+            self.set_transition((rcPos[0], rcPos[1], connect_directions[2]), connect_directions[3], 1)
+            self.set_transition((rcPos[0], rcPos[1], connect_directions[3]), connect_directions[2], 1)
+            self.set_transition((rcPos[0], rcPos[1], connect_directions[3]), connect_directions[3], 1)
+        return True
+
+
+def mirror(dir):
+    return (dir + 2) % 4
 # TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?)
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index dedd76b6bfd04c13ad59092adfefbde6ae98fc18..996bd73ad9de598eb162a937c135681675119ad3 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -5,10 +5,8 @@ Generator functions are functions that take width, height and num_resets as argu
 a GridTransitionMap object.
 """
 
-import numpy as np
-
 from flatland.core.grid.grid4_astar import a_star
-from flatland.core.grid.grid4_utils import get_direction, mirror, get_new_position
+from flatland.core.grid.grid4_utils import get_direction, mirror
 
 
 def connect_rail(rail_trans, rail_array, start, end):
@@ -57,79 +55,141 @@ def connect_rail(rail_trans, rail_array, start, end):
     return path
 
 
-def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
+def connect_nodes(rail_trans, rail_array, start, end):
+    """
+    Creates a new path [start,end] in rail_array, based on rail_trans.
+    """
+    # in the worst case we will need to do a A* search, so we might as well set that up
+    path = a_star(rail_trans, rail_array, start, end)
+    if len(path) < 2:
+        return []
+    current_dir = get_direction(path[0], path[1])
+    end_pos = path[-1]
+    for index in range(len(path) - 1):
+        current_pos = path[index]
+        new_pos = path[index + 1]
+        new_dir = get_direction(current_pos, new_pos)
+
+        new_trans = rail_array[current_pos]
+        if index == 0:
+            if new_trans == 0:
+                # end-point
+                # don't set any transition at node yet
+                new_trans = 0
+            else:
+                # into existing rail
+                new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+        else:
+            # set the forward path
+            new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+            # set the backwards path
+            new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
+        rail_array[current_pos] = new_trans
+
+        if new_pos == end_pos:
+            # setup end pos setup
+            new_trans_e = rail_array[end_pos]
+            if new_trans_e == 0:
+                # end-point
+                # don't set any transition at node yet
+
+                new_trans_e = 0
+            else:
+                # into existing rail
+                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
+            rail_array[end_pos] = new_trans_e
+
+        current_dir = new_dir
+    return path
+
+
+def connect_from_nodes(rail_trans, rail_array, start, end):
+    """
+    Creates a new path [start,end] in rail_array, based on rail_trans.
     """
-    Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
+    # in the worst case we will need to do a A* search, so we might as well set that up
+    path = a_star(rail_trans, rail_array, start, end)
+    if len(path) < 2:
+        return []
+    current_dir = get_direction(path[0], path[1])
+    end_pos = path[-1]
+    for index in range(len(path) - 1):
+        current_pos = path[index]
+        new_pos = path[index + 1]
+        new_dir = get_direction(current_pos, new_pos)
+
+        new_trans = rail_array[current_pos]
+        if index == 0:
+            if new_trans == 0:
+                # end-point
+                # need to flip direction because of how end points are defined
+                new_trans = 0
+            else:
+                # into existing rail
+                new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+        else:
+            # set the forward path
+            new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+            # set the backwards path
+            new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
+        rail_array[current_pos] = new_trans
 
-    TODO: add extensive documentation, as users may need this function to simplify their custom level generators.
+        if new_pos == end_pos:
+            # setup end pos setup
+            new_trans_e = rail_array[end_pos]
+            if new_trans_e == 0:
+                # end-point
+                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
+            else:
+                # into existing rail
+                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
+            rail_array[end_pos] = new_trans_e
+
+        current_dir = new_dir
+    return path
+
+
+def connect_to_nodes(rail_trans, rail_array, start, end):
     """
+    Creates a new path [start,end] in rail_array, based on rail_trans.
+    """
+    # in the worst case we will need to do a A* search, so we might as well set that up
+    path = a_star(rail_trans, rail_array, start, end)
+    if len(path) < 2:
+        return []
+    current_dir = get_direction(path[0], path[1])
+    end_pos = path[-1]
+    for index in range(len(path) - 1):
+        current_pos = path[index]
+        new_pos = path[index + 1]
+        new_dir = get_direction(current_pos, new_pos)
 
-    def _path_exists(rail, start, direction, end):
-        # BFS - Check if a path exists between the 2 nodes
-
-        visited = set()
-        stack = [(start, direction)]
-        while stack:
-            node = stack.pop()
-            if node[0][0] == end[0] and node[0][1] == end[1]:
-                return 1
-            if node not in visited:
-                visited.add(node)
-                moves = rail.get_transitions(node[0][0], node[0][1], node[1])
-                for move_index in range(4):
-                    if moves[move_index]:
-                        stack.append((get_new_position(node[0], move_index),
-                                      move_index))
-
-                # If cell is a dead-end, append previous node with reversed
-                # orientation!
-                nbits = 0
-                tmp = rail.get_full_transitions(node[0][0], node[0][1])
-                while tmp > 0:
-                    nbits += (tmp & 1)
-                    tmp = tmp >> 1
-                if nbits == 1:
-                    stack.append((node[0], (node[1] + 2) % 4))
-
-        return 0
-
-    valid_positions = []
-    for r in range(rail.height):
-        for c in range(rail.width):
-            if rail.get_full_transitions(r, c) > 0:
-                valid_positions.append((r, c))
-
-    re_generate = True
-    while re_generate:
-        agents_position = [
-            valid_positions[i] for i in
-            np.random.choice(len(valid_positions), num_agents)]
-        agents_target = [
-            valid_positions[i] for i in
-            np.random.choice(len(valid_positions), num_agents)]
-
-        # agents_direction must be a direction for which a solution is
-        # guaranteed.
-        agents_direction = [0] * num_agents
-        re_generate = False
-        for i in range(num_agents):
-            valid_movements = []
-            for direction in range(4):
-                position = agents_position[i]
-                moves = rail.get_transitions(position[0], position[1], direction)
-                for move_index in range(4):
-                    if moves[move_index]:
-                        valid_movements.append((direction, move_index))
-
-            valid_starting_directions = []
-            for m in valid_movements:
-                new_position = get_new_position(agents_position[i], m[1])
-                if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]):
-                    valid_starting_directions.append(m[0])
-
-            if len(valid_starting_directions) == 0:
-                re_generate = True
+        new_trans = rail_array[current_pos]
+        if index == 0:
+            if new_trans == 0:
+                # end-point
+                # need to flip direction because of how end points are defined
+                new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
             else:
-                agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]
+                # into existing rail
+                new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+        else:
+            # set the forward path
+            new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+            # set the backwards path
+            new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
+        rail_array[current_pos] = new_trans
 
-    return agents_position, agents_direction, agents_target
+        if new_pos == end_pos:
+            # setup end pos setup
+            new_trans_e = rail_array[end_pos]
+            if new_trans_e == 0:
+                # end-point
+                new_trans_e = 0
+            else:
+                # into existing rail
+                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
+            rail_array[end_pos] = new_trans_e
+
+        current_dir = new_dir
+    return path
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 4158675cae63394ed768bfc36aaef9cd5f44da7e..c4fed2e07a97b00b786a7db8eb06af247d3ede8a 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -383,8 +383,9 @@ class TreeObsForRailEnv(ObservationBuilder):
                     elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
                         conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
                         for ca in conflicting_agent[0]:
-                            if direction != self.predicted_dir[pre_step][ca] and cell_transitions[self._reverse_dir(
-                                self.predicted_dir[pre_step][ca])] == 1 and tot_dist < potential_conflict:
+                            if direction != self.predicted_dir[pre_step][ca] \
+                                and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
+                                and tot_dist < potential_conflict:  # noqa: E125
                                 potential_conflict = tot_dist
                             if self.env.dones[ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
@@ -394,7 +395,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                         conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
                         for ca in conflicting_agent[0]:
                             if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir(
-                                self.predicted_dir[post_step][ca])] == 1 and tot_dist < potential_conflict:
+                                self.predicted_dir[post_step][ca])] == 1 \
+                                and tot_dist < potential_conflict:  # noqa: E125
                                 potential_conflict = tot_dist
                             if self.env.dones[ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 2281282977d8c9d972f13526efa9d96abaf84a52..62efbdc5bc0781c4c7482412dafd98710ed9d14e 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -2,7 +2,7 @@
 Definition of the RailEnv environment.
 """
 # TODO:  _ this is a global method --> utils or remove later
-
+import warnings
 from enum import IntEnum
 
 import msgpack
@@ -11,9 +11,11 @@ import numpy as np
 
 from flatland.core.env import Environment
 from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
-from flatland.envs.generators import random_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.rail_generators import random_rail_generator, RailGenerator
+from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
 
 m.patch()
 
@@ -91,7 +93,8 @@ class RailEnv(Environment):
     def __init__(self,
                  width,
                  height,
-                 rail_generator=random_rail_generator(),
+                 rail_generator: RailGenerator = random_rail_generator(),
+                 schedule_generator: ScheduleGenerator = random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
                  max_episode_steps=None,
@@ -107,13 +110,12 @@ class RailEnv(Environment):
             height and agents handles of a  rail environment, along with the number of times
             the env has been reset, and returns a GridTransitionMap object and a list of
             starting positions, targets, and initial orientations for agent handle.
-            Implemented functions are:
-                random_rail_generator : generate a random rail of given size
-                rail_from_grid_transition_map(rail_map) : generate a rail from
-                                        a GridTransitionMap object
-                rail_from_manual_sp ecifications_generator(rail_spec) : generate a rail from
-                                        a rail specifications array
-                TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
+            The rail_generator can pass a distance map in the hints or information for specific schedule_generators.
+            Implementations can be found in flatland/envs/rail_generators.py
+        schedule_generator : function
+            The schedule_generator function is a function that takes the grid, the number of agents and optional hints
+            and returns a list of starting positions, targets, initial orientations and speed for all agent handles.
+            Implementations can be found in flatland/envs/schedule_generators.py
         width : int
             The width of the rail map. Potentially in the future,
             a range of widths to sample from.
@@ -131,8 +133,10 @@ class RailEnv(Environment):
         file_name: you can load a pickle file.
         """
 
+        self.rail_generator: RailGenerator = rail_generator
+        self.schedule_generator: ScheduleGenerator = schedule_generator
         self.rail_generator = rail_generator
-        self.rail = None
+        self.rail: GridTransitionMap = None
         self.width = width
         self.height = height
 
@@ -213,18 +217,27 @@ class RailEnv(Environment):
             if replace_agents then regenerate the agents static.
             Relies on the rail_generator returning agent_static lists (pos, dir, target)
         """
-        tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
+        rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
-        # Check if generator provided a distance map TODO: Make this check safer!
-        if len(tRailAgents) > 5:
-            self.obs_builder.distance_map = tRailAgents[-1]
+        if optionals and 'distance_maps' in optionals:
+            self.obs_builder.distance_map = optionals['distance_maps']
 
         if regen_rail or self.rail is None:
-            self.rail = tRailAgents[0]
+            self.rail = rail
             self.height, self.width = self.rail.grid.shape
+            for r in range(self.height):
+                for c in range(self.width):
+                    rcPos = (r, c)
+                    check = self.rail.cell_neighbours_valid(rcPos, True)
+                    if not check:
+                        warnings.warn("Invalid grid at {} -> {}".format(rcPos, check))
 
         if replace_agents:
-            self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])
+            agents_hints = None
+            if optionals and 'agents_hints' in optionals:
+                agents_hints = optionals['agents_hints']
+            self.agents_static = EnvAgentStatic.from_lists(
+                *self.schedule_generator(self.rail, self.get_num_agents(), hints=agents_hints))
 
         self.restart_agents()
 
@@ -258,8 +271,7 @@ class RailEnv(Environment):
         agent.malfunction_data['next_malfunction'] -= 1
 
         # Only agents that have a positive rate for malfunctions and are not currently broken are considered
-        if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data[
-            'malfunction']:
+        if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction']:
 
             # If counter has come to zero --> Agent has malfunction
             # set next malfunction time and duration of current malfunction
diff --git a/flatland/envs/generators.py b/flatland/envs/rail_generators.py
similarity index 57%
rename from flatland/envs/generators.py
rename to flatland/envs/rail_generators.py
index 355f5502992a34f8d58d4dbd80028eb4dd71cc48..40ec2e0df89a3de48bf0a2a4430de3e30fa556e5 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/rail_generators.py
@@ -1,3 +1,7 @@
+"""Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
+import warnings
+from typing import Callable, Tuple, Any, Optional
+
 import msgpack
 import numpy as np
 
@@ -5,29 +9,34 @@ from flatland.core.grid.grid4_utils import get_direction, mirror
 from flatland.core.grid.grid_utils import distance_on_rail
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.agent_utils import EnvAgentStatic
-from flatland.envs.grid4_generators_utils import connect_rail
-from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail
+from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes
+
+RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]]
+RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
 
 
-def empty_rail_generator():
+def empty_rail_generator() -> RailGenerator:
     """
     Returns a generator which returns an empty rail mail with no agents.
     Primarily used by the editor
     """
 
-    def generator(width, height, num_agents=0, num_resets=0):
+    def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct:
         rail_trans = RailEnvTransitions()
         grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
         rail_array = grid_map.grid
         rail_array.fill(0)
 
-        return grid_map, [], [], [], []
+        return grid_map, None
 
     return generator
 
 
-def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=99999, seed=0):
+def complex_rail_generator(nr_start_goal=1,
+                           nr_extra=100,
+                           min_dist=20,
+                           max_dist=99999,
+                           seed=0) -> RailGenerator:
     """
     Parameters
     -------
@@ -47,8 +56,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
         if num_agents > nr_start_goal:
             num_agents = nr_start_goal
             print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
-        rail_trans = RailEnvTransitions()
-        grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
+        grid_map = GridTransitionMap(width=width, height=height, transitions=RailEnvTransitions())
         rail_array = grid_map.grid
         rail_array.fill(0)
 
@@ -72,6 +80,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
         # - return transition map + list of [start_pos, start_dir, goal_pos] points
         #
 
+        rail_trans = grid_map.transitions
         start_goal = []
         start_dir = []
         nr_created = 0
@@ -141,11 +150,10 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
             if len(new_path) >= 2:
                 nr_created += 1
 
-        agents_position = [sg[0] for sg in start_goal[:num_agents]]
-        agents_target = [sg[1] for sg in start_goal[:num_agents]]
-        agents_direction = start_dir[:num_agents]
-
-        return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+        return grid_map, {'agents_hints': {
+            'start_goal': start_goal,
+            'start_dir': start_dir
+        }}
 
     return generator
 
@@ -189,22 +197,18 @@ def rail_from_manual_specifications_generator(rail_spec):
                 effective_transition_cell = rail_env_transitions.rotate_transition(basic_type_of_cell_, rotation_cell_)
                 rail.set_transitions((r, c), effective_transition_cell)
 
-        agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
-            rail,
-            num_agents)
-
-        return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+        return [rail, None]
 
     return generator
 
 
-def rail_from_file(filename):
+def rail_from_file(filename) -> RailGenerator:
     """
     Utility to load pickle file
 
     Parameters
     -------
-    input_file : Pickle file generated by env.save() or editor
+    filename : Pickle file generated by env.save() or editor
 
     Returns
     -------
@@ -222,26 +226,16 @@ def rail_from_file(filename):
         grid = np.array(data[b"grid"])
         rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions)
         rail.grid = grid
-        # agents are always reset as not moving
-        agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
-        # setup with loaded data
-        agents_position = [a.position for a in agents_static]
-        agents_direction = [a.direction for a in agents_static]
-        agents_target = [a.target for a in agents_static]
         if b"distance_maps" in data.keys():
             distance_maps = data[b"distance_maps"]
             if len(distance_maps) > 0:
-                return rail, agents_position, agents_direction, agents_target, [1.0] * len(
-                    agents_position), distance_maps
-            else:
-                return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
-        else:
-            return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+                return rail, {'distance_maps': distance_maps}
+        return [rail, None]
 
     return generator
 
 
-def rail_from_grid_transition_map(rail_map):
+def rail_from_grid_transition_map(rail_map) -> RailGenerator:
     """
     Utility to convert a rail given by a GridTransitionMap map with the correct
     16-bit transitions specifications.
@@ -257,17 +251,13 @@ def rail_from_grid_transition_map(rail_map):
         Generator function that always returns the given `rail_map' object.
     """
 
-    def generator(width, height, num_agents, num_resets=0):
-        agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
-            rail_map,
-            num_agents)
-
-        return rail_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+    def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
+        return rail_map, None
 
     return generator
 
 
-def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
+def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGenerator:
     """
     Dummy random level generator:
     - fill in cells at random in [width-2, height-2]
@@ -299,7 +289,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
         The matrix with the correct 16-bit bitmaps for each cell.
     """
 
-    def generator(width, height, num_agents, num_resets=0):
+    def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
         t_utils = RailEnvTransitions()
 
         transition_probability = cell_type_relative_proportion
@@ -531,10 +521,277 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
         return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
         return_rail.grid = tmp_rail
 
-        agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
-            return_rail,
-            num_agents)
+        return return_rail, None
+
+    return generator
+
+
+def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2, min_node_dist=20, node_radius=2,
+                          num_neighb=3, realistic_mode=False, enhance_intersection=False, seed=0):
+    """
+    This is a level generator which generates complex sparse rail configurations
+
+    :param num_cities: Number of city node (can hold trainstations)
+    :param num_intersections: Number of intersection that city nodes can connect to
+    :param num_trainstations: Total number of trainstations in env
+    :param min_node_dist: Minimal distance between nodes
+    :param node_radius: Proximity of trainstations to center of city node
+    :param num_neighb: Number of neighbouring nodes each node connects to
+    :param realistic_mode: True -> NOdes evenly distirbuted in env, False-> Random distribution of nodes
+    :param enhance_intersection: True -> Extra rail elements added at intersections
+    :param seed: Random Seed
+    :return:
+        -------
+    numpy.ndarray of type numpy.uint16
+        The matrix with the correct 16-bit bitmaps for each cell.
+    """
+
+    def generator(width, height, num_agents, num_resets=0):
+
+        if num_agents > num_trainstations:
+            num_agents = num_trainstations
+            warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents")
+
+        rail_trans = RailEnvTransitions()
+        grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
+        rail_array = grid_map.grid
+        rail_array.fill(0)
+        np.random.seed(seed + num_resets)
+
+        # Generate a set of nodes for the sparse network
+        # Try to connect cities to nodes first
+        node_positions = []
+        city_positions = []
+        intersection_positions = []
+
+        # Evenly distribute cities and intersections
+        if realistic_mode:
+            tot_num_node = num_intersections + num_cities
+            nodes_ratio = height / width
+            nodes_per_row = int(np.ceil(np.sqrt(tot_num_node * nodes_ratio)))
+            nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row))
+            x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int)
+            y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int)
+
+        for node_idx in range(num_cities + num_intersections):
+            to_close = True
+            tries = 0
+            if not realistic_mode:
+                while to_close:
+                    x_tmp = node_radius + np.random.randint(height - node_radius)
+                    y_tmp = node_radius + np.random.randint(width - node_radius)
+                    to_close = False
+
+                    # Check distance to cities
+                    for node_pos in city_positions:
+                        if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
+                            to_close = True
+
+                    # CHeck distance to intersections
+                    for node_pos in intersection_positions:
+                        if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
+                            to_close = True
+
+                    if not to_close:
+                        node_positions.append((x_tmp, y_tmp))
+                        if node_idx < num_cities:
+                            city_positions.append((x_tmp, y_tmp))
+                        else:
+                            intersection_positions.append((x_tmp, y_tmp))
+                    tries += 1
+                    if tries > 100:
+                        warnings.warn("Could not set nodes, please change initial parameters!!!!")
+                        break
+            else:
+                x_tmp = x_positions[node_idx % nodes_per_row]
+                y_tmp = y_positions[node_idx // nodes_per_row]
+                if len(city_positions) < num_cities and (node_idx % (tot_num_node // num_cities)) == 0:
+                    city_positions.append((x_tmp, y_tmp))
+                else:
+                    intersection_positions.append((x_tmp, y_tmp))
+
+        node_positions = city_positions + intersection_positions
+
+        # Chose node connection
+        # Set up list of available nodes to connect to
+        available_nodes_full = np.arange(num_cities + num_intersections)
+        available_cities = np.arange(num_cities)
+        available_intersections = np.arange(num_cities, num_cities + num_intersections)
+
+        # Start at some node
+        current_node = np.random.randint(len(available_nodes_full))
+        node_stack = [current_node]
+        allowed_connections = num_neighb
+        first_node = True
+        while len(node_stack) > 0:
+            current_node = node_stack[0]
+            delete_idx = np.where(available_nodes_full == current_node)
+            available_nodes_full = np.delete(available_nodes_full, delete_idx, 0)
+            # Priority city to intersection connections
+            if current_node < num_cities and len(available_intersections) > 0:
+                available_nodes = available_intersections
+                delete_idx = np.where(available_cities == current_node)
+                available_cities = np.delete(available_cities, delete_idx, 0)
+
+            # Priority intersection to city connections
+            elif current_node >= num_cities and len(available_cities) > 0:
+                available_nodes = available_cities
+                delete_idx = np.where(available_intersections == current_node)
+                available_intersections = np.delete(available_intersections, delete_idx, 0)
+
+            # If no options possible connect to whatever node is still available
+            else:
+                available_nodes = available_nodes_full
+
+            # Sort available neighbors according to their distance.
+            node_dist = []
+            for av_node in available_nodes:
+                node_dist.append(distance_on_rail(node_positions[current_node], node_positions[av_node]))
+            available_nodes = available_nodes[np.argsort(node_dist)]
+
+            # Set number of neighboring nodes
+            if len(available_nodes) >= allowed_connections:
+                connected_neighb_idx = available_nodes[:allowed_connections]
+            else:
+                connected_neighb_idx = available_nodes
+
+            # Less connections for subsequent nodes
+            if first_node:
+                allowed_connections -= 1
+                first_node = False
+
+            # Connect to the neighboring nodes
+            for neighb in connected_neighb_idx:
+                if neighb not in node_stack:
+                    node_stack.append(neighb)
+                connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb])
+            node_stack.pop(0)
+
+        # Place train stations close to the node
+        # We currently place them uniformly distirbuted among all cities
+        if num_cities > 1:
+            train_stations = [[] for i in range(num_cities)]
+            built_num_trainstation = 0
+            spot_found = True
+            for station in range(num_trainstations):
+                trainstation_node = int(station / num_trainstations * num_cities)
+
+                station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius),
+                                    0,
+                                    height - 1)
+                station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius),
+                                    0,
+                                    width - 1)
+                tries = 0
+                while (station_x, station_y) in train_stations \
+                    or (station_x, station_y) == node_positions[trainstation_node] \
+                    or rail_array[(station_x, station_y)] != 0:  # noqa: E125
+
+                    station_x = np.clip(
+                        node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius),
+                        0,
+                        height - 1)
+                    station_y = np.clip(
+                        node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius),
+                        0,
+                        width - 1)
+                    tries += 1
+                    if tries > 100:
+                        warnings.warn("Could not set trainstations, please change initial parameters!!!!")
+                        spot_found = False
+                        break
+                if spot_found:
+                    train_stations[trainstation_node].append((station_x, station_y))
+
+                # Connect train station to the correct node
+                connection = connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node],
+                                                (station_x, station_y))
+                # Check if connection was made
+                if len(connection) == 0:
+                    train_stations[trainstation_node].pop(-1)
+                else:
+                    built_num_trainstation += 1
+
+        # Adjust the number of agents if you could not build enough trainstations
+
+        if num_agents > built_num_trainstation:
+            num_agents = built_num_trainstation
+            warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents")
+
+        # Place passing lanes at intersections
+        # We currently place them uniformly distirbuted among all cities
+        if enhance_intersection:
+
+            for intersection in range(num_intersections):
+                intersect_x_1 = np.clip(intersection_positions[intersection][0] + np.random.randint(1, 3),
+                                        1,
+                                        height - 2)
+                intersect_y_1 = np.clip(intersection_positions[intersection][1] + np.random.randint(-3, 3),
+                                        2,
+                                        width - 2)
+                intersect_x_2 = np.clip(
+                    intersection_positions[intersection][0] + np.random.randint(-3, -1),
+                    1,
+                    height - 2)
+                intersect_y_2 = np.clip(
+                    intersection_positions[intersection][1] + np.random.randint(-3, 3),
+                    1,
+                    width - 2)
+
+                # Connect train station to the correct node
+                connect_nodes(rail_trans, rail_array, (intersect_x_1, intersect_y_1),
+                              (intersect_x_2, intersect_y_2))
+                connect_nodes(rail_trans, rail_array, intersection_positions[intersection],
+                              (intersect_x_1, intersect_y_1))
+                connect_nodes(rail_trans, rail_array, intersection_positions[intersection],
+                              (intersect_x_2, intersect_y_2))
+                grid_map.fix_transitions((intersect_x_1, intersect_y_1))
+                grid_map.fix_transitions((intersect_x_2, intersect_y_2))
+
+        # Fix all nodes with illegal transition maps
+        for current_node in node_positions:
+            grid_map.fix_transitions(current_node)
+
+        # Generate start and target node directory for all agents.
+        # Assure that start and target are not in the same node
+        agent_start_targets_nodes = []
+
+        # Slot availability in node
+        node_available_start = []
+        node_available_target = []
+        for node_idx in range(num_cities):
+            node_available_start.append(len(train_stations[node_idx]))
+            node_available_target.append(len(train_stations[node_idx]))
+
+        # Assign agents to slots
+        for agent_idx in range(num_agents):
+            avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0]
+            avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0]
+            start_node = np.random.choice(avail_start_nodes)
+            target_node = np.random.choice(avail_target_nodes)
+            tries = 0
+            found_agent_pair = True
+            while target_node == start_node:
+                target_node = np.random.choice(avail_target_nodes)
+                tries += 1
+                # Test again with new start node if no pair is found (This code needs to be improved)
+                if (tries + 1) % 10 == 0:
+                    start_node = np.random.choice(avail_start_nodes)
+                if tries > 100:
+                    warnings.warn("Could not set trainstations, removing agent!")
+                    found_agent_pair = False
+                    break
+            if found_agent_pair:
+                node_available_start[start_node] -= 1
+                node_available_target[target_node] -= 1
+                agent_start_targets_nodes.append((start_node, target_node))
+            else:
+                num_agents -= 1
 
-        return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+        return grid_map, {'agents_hints': {
+            'num_agents': num_agents,
+            'agent_start_targets_nodes': agent_start_targets_nodes,
+            'train_stations': train_stations
+        }}
 
     return generator
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
new file mode 100644
index 0000000000000000000000000000000000000000..4843e0040d80b79de54e8ed57674a37884ef6809
--- /dev/null
+++ b/flatland/envs/schedule_generators.py
@@ -0,0 +1,235 @@
+"""Schedule generators (railway undertaking, "EVU")."""
+import warnings
+from typing import Tuple, List, Callable, Mapping, Optional, Any
+
+import msgpack
+import numpy as np
+
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.agent_utils import EnvAgentStatic
+
+AgentPosition = Tuple[int, int]
+ScheduleGeneratorProduct = Tuple[List[AgentPosition], List[AgentPosition], List[AgentPosition], List[float]]
+ScheduleGenerator = Callable[[GridTransitionMap, int, Optional[Any]], ScheduleGeneratorProduct]
+
+
+def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float] = None) -> List[float]:
+    """
+    Parameters
+    -------
+    nb_agents : int
+        The number of agents to generate a speed for
+    speed_ratio_map : Mapping[float,float]
+        A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1.
+
+    Returns
+    -------
+    List[float]
+        A list of size nb_agents of speeds with the corresponding probabilistic ratios.
+    """
+    if speed_ratio_map is None:
+        return [1.0] * nb_agents
+
+    nb_classes = len(speed_ratio_map.keys())
+    speed_ratio_map_as_list: List[Tuple[float, float]] = list(speed_ratio_map.items())
+    speed_ratios = list(map(lambda t: t[1], speed_ratio_map_as_list))
+    speeds = list(map(lambda t: t[0], speed_ratio_map_as_list))
+    return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios)))
+
+
+def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
+        start_goal = hints['start_goal']
+        start_dir = hints['start_dir']
+        agents_position = [sg[0] for sg in start_goal[:num_agents]]
+        agents_target = [sg[1] for sg in start_goal[:num_agents]]
+        agents_direction = start_dir[:num_agents]
+
+        if speed_ratio_map:
+            speeds = speed_initialization_helper(num_agents, speed_ratio_map)
+        else:
+            speeds = [1.0] * len(agents_position)
+
+        return agents_position, agents_direction, agents_target, speeds
+
+    return generator
+
+
+def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
+        train_stations = hints['train_stations']
+        agent_start_targets_nodes = hints['agent_start_targets_nodes']
+        num_agents = hints['num_agents']
+        # Place agents and targets within available train stations
+        agents_position = []
+        agents_target = []
+        agents_direction = []
+        for agent_idx in range(num_agents):
+            # Set target for agent
+            current_target_node = agent_start_targets_nodes[agent_idx][1]
+            target_station_idx = np.random.randint(len(train_stations[current_target_node]))
+            target = train_stations[current_target_node][target_station_idx]
+            tries = 0
+            while (target[0], target[1]) in agents_target:
+                target_station_idx = np.random.randint(len(train_stations[current_target_node]))
+                target = train_stations[current_target_node][target_station_idx]
+                tries += 1
+                if tries > 100:
+                    warnings.warn("Could not set target position, removing an agent")
+                    break
+            agents_target.append((target[0], target[1]))
+
+            # Set start for agent
+            current_start_node = agent_start_targets_nodes[agent_idx][0]
+            start_station_idx = np.random.randint(len(train_stations[current_start_node]))
+            start = train_stations[current_start_node][start_station_idx]
+            tries = 0
+            while (start[0], start[1]) in agents_position:
+                tries += 1
+                if tries > 100:
+                    warnings.warn("Could not set start position, please change initial parameters!!!!")
+                    break
+                start_station_idx = np.random.randint(len(train_stations[current_start_node]))
+                start = train_stations[current_start_node][start_station_idx]
+
+            agents_position.append((start[0], start[1]))
+
+            # Orient the agent correctly
+            for orientation in range(4):
+                transitions = rail.get_transitions(start[0], start[1], orientation)
+                if any(transitions) > 0:
+                    agents_direction.append(orientation)
+                    continue
+
+        if speed_ratio_map:
+            speeds = speed_initialization_helper(num_agents, speed_ratio_map)
+        else:
+            speeds = [1.0] * len(agents_position)
+
+        return agents_position, agents_direction, agents_target, speeds
+
+    return generator
+
+
+def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
+    """
+    Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
+
+    Parameters
+    -------
+        rail : GridTransitionMap
+            The railway to place agents on.
+        num_agents : int
+            The number of agents to generate a speed for
+        speed_ratio_map : Mapping[float,float]
+            A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1.
+    Returns
+    -------
+        Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]]
+        initial positions, directions, targets speeds
+    """
+
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct:
+
+        valid_positions = []
+        for r in range(rail.height):
+            for c in range(rail.width):
+                if rail.get_full_transitions(r, c) > 0:
+                    valid_positions.append((r, c))
+        if len(valid_positions) == 0:
+            return [], [], [], []
+
+        if len(valid_positions) < num_agents:
+            warnings.warn("schedule_generators: len(valid_positions) < num_agents")
+            return [], [], [], []
+
+        agents_position_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)]
+        agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)]
+        agents_target_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)]
+        agents_target = [valid_positions[agents_target_idx[i]] for i in range(num_agents)]
+        update_agents = np.zeros(num_agents)
+
+        re_generate = True
+        cnt = 0
+        while re_generate:
+            cnt += 1
+            if cnt > 1:
+                print("re_generate cnt={}".format(cnt))
+            if cnt > 1000:
+                raise Exception("After 1000 re_generates still not success, giving up.")
+            # update position
+            for i in range(num_agents):
+                if update_agents[i] == 1:
+                    x = np.setdiff1d(np.arange(len(valid_positions)), agents_position_idx)
+                    agents_position_idx[i] = np.random.choice(x)
+                    agents_position[i] = valid_positions[agents_position_idx[i]]
+                    x = np.setdiff1d(np.arange(len(valid_positions)), agents_target_idx)
+                    agents_target_idx[i] = np.random.choice(x)
+                    agents_target[i] = valid_positions[agents_target_idx[i]]
+            update_agents = np.zeros(num_agents)
+
+            # agents_direction must be a direction for which a solution is
+            # guaranteed.
+            agents_direction = [0] * num_agents
+            re_generate = False
+            for i in range(num_agents):
+                valid_movements = []
+                for direction in range(4):
+                    position = agents_position[i]
+                    moves = rail.get_transitions(position[0], position[1], direction)
+                    for move_index in range(4):
+                        if moves[move_index]:
+                            valid_movements.append((direction, move_index))
+
+                valid_starting_directions = []
+                for m in valid_movements:
+                    new_position = get_new_position(agents_position[i], m[1])
+                    if m[0] not in valid_starting_directions and rail.check_path_exists(new_position, m[1],
+                                                                                        agents_target[i]):
+                        valid_starting_directions.append(m[0])
+
+                if len(valid_starting_directions) == 0:
+                    update_agents[i] = 1
+                    warnings.warn("reset position for agent[{}]: {} -> {}".format(i, agents_position[i], agents_target[i]))
+                    re_generate = True
+                    break
+                else:
+                    agents_direction[i] = valid_starting_directions[
+                        np.random.choice(len(valid_starting_directions), 1)[0]]
+
+        agents_speed = speed_initialization_helper(num_agents, speed_ratio_map)
+        return agents_position, agents_direction, agents_target, agents_speed
+
+    return generator
+
+
+def schedule_from_file(filename) -> ScheduleGenerator:
+    """
+    Utility to load pickle file
+
+    Parameters
+    -------
+    input_file : Pickle file generated by env.save() or editor
+
+    Returns
+    -------
+    Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]]
+        initial positions, directions, targets speeds
+    """
+
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct:
+        with open(filename, "rb") as file_in:
+            load_data = file_in.read()
+        data = msgpack.unpackb(load_data, use_list=False)
+
+        # agents are always reset as not moving
+        agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
+        # setup with loaded data
+        agents_position = [a.position for a in agents_static]
+        agents_direction = [a.direction for a in agents_static]
+        agents_target = [a.target for a in agents_static]
+
+        return agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+
+    return generator
diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py
index a4968c0c8e827c060e0e3f7de0cf28cc0658089b..f2dac1c8705a651e2a7be026b4bd82a961efbbdd 100644
--- a/flatland/evaluators/client.py
+++ b/flatland/evaluators/client.py
@@ -1,18 +1,21 @@
-import redis
+import hashlib
 import json
+import logging
 import os
-import numpy as np
+import random
+import time
+
 import msgpack
 import msgpack_numpy as m
-import hashlib
-import random
-from flatland.evaluators import messages
-from flatland.envs.rail_env import RailEnv
-from flatland.envs.generators import rail_from_file
+import numpy as np
+import redis
+
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
-import time
-import logging
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import rail_from_file
+from flatland.evaluators import messages
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 m.patch()
@@ -22,8 +25,8 @@ def are_dicts_equal(d1, d2):
     """ return True if all keys and values are the same """
     return all(k in d2 and d1[k] == d2[k]
                for k in d1) \
-        and all(k in d1 and d1[k] == d2[k]
-               for k in d2)
+           and all(k in d1 and d1[k] == d2[k]
+                   for k in d2)
 
 
 class FlatlandRemoteClient(object):
@@ -41,39 +44,40 @@ class FlatlandRemoteClient(object):
         where `service_id` is either provided as an `env` variable or is
         instantiated to "flatland_rl_redis_service_id"
     """
-    def __init__(self,  
-                remote_host='127.0.0.1',
-                remote_port=6379,
-                remote_db=0,
-                remote_password=None,
-                test_envs_root=None,
-                verbose=False):
+
+    def __init__(self,
+                 remote_host='127.0.0.1',
+                 remote_port=6379,
+                 remote_db=0,
+                 remote_password=None,
+                 test_envs_root=None,
+                 verbose=False):
 
         self.remote_host = remote_host
         self.remote_port = remote_port
         self.remote_db = remote_db
         self.remote_password = remote_password
         self.redis_pool = redis.ConnectionPool(
-                                host=remote_host,
-                                port=remote_port,
-                                db=remote_db,
-                                password=remote_password)
+            host=remote_host,
+            port=remote_port,
+            db=remote_db,
+            password=remote_password)
         self.namespace = "flatland-rl"
         self.service_id = os.getenv(
-                            'FLATLAND_RL_SERVICE_ID',
-                            'FLATLAND_RL_SERVICE_ID'
-                            )
+            'FLATLAND_RL_SERVICE_ID',
+            'FLATLAND_RL_SERVICE_ID'
+        )
         self.command_channel = "{}::{}::commands".format(
-                                    self.namespace,
-                                    self.service_id
-                                )
+            self.namespace,
+            self.service_id
+        )
         if test_envs_root:
             self.test_envs_root = test_envs_root
         else:
             self.test_envs_root = os.getenv(
-                                'AICROWD_TESTS_FOLDER',
-                                '/tmp/flatland_envs'
-                                )
+                'AICROWD_TESTS_FOLDER',
+                '/tmp/flatland_envs'
+            )
 
         self.verbose = verbose
 
@@ -85,12 +89,12 @@ class FlatlandRemoteClient(object):
 
     def _generate_response_channel(self):
         random_hash = hashlib.md5(
-                        "{}".format(
-                                random.randint(0, 10**10)
-                            ).encode('utf-8')).hexdigest()
+            "{}".format(
+                random.randint(0, 10 ** 10)
+            ).encode('utf-8')).hexdigest()
         response_channel = "{}::{}::response::{}".format(self.namespace,
-                                                        self.service_id,
-                                                        random_hash)
+                                                         self.service_id,
+                                                         random_hash)
         return response_channel
 
     def _blocking_request(self, _request):
@@ -124,9 +128,9 @@ class FlatlandRemoteClient(object):
         if self.verbose:
             print("Response : ", _response)
         _response = msgpack.unpackb(
-                        _response, 
-                        object_hook=m.decode, 
-                        encoding="utf8")
+            _response,
+            object_hook=m.decode,
+            encoding="utf8")
         if _response['type'] == messages.FLATLAND_RL.ERROR:
             raise Exception(str(_response["payload"]))
         else:
@@ -181,7 +185,7 @@ class FlatlandRemoteClient(object):
                 "Did you remember to set the AICROWD_TESTS_FOLDER environment variable "
                 "to point to the location of the Tests folder ? \n"
                 "We are currently looking at `{}` for the tests".format(self.test_envs_root)
-                )
+            )
         print("Current env path : ", test_env_file_path)
         self.env = RailEnv(
             width=1,
@@ -207,7 +211,7 @@ class FlatlandRemoteClient(object):
         _request['payload']['action'] = action
         _response = self._blocking_request(_request)
         _payload = _response['payload']
-        
+
         # remote_observation = _payload['observation']
         remote_reward = _payload['reward']
         remote_done = _payload['done']
@@ -216,14 +220,14 @@ class FlatlandRemoteClient(object):
         # Replicate the action in the local env
         local_observation, local_reward, local_done, local_info = \
             self.env.step(action)
-        
+
         print(local_reward)
         if not are_dicts_equal(remote_reward, local_reward):
             raise Exception("local and remote `reward` are diverging")
             print(remote_reward, local_reward)
         if not are_dicts_equal(remote_done, local_done):
             raise Exception("local and remote `done` are diverging")
-        
+
         # Return local_observation instead of remote_observation
         # as the remote_observation is build using a dummy observation
         # builder
@@ -250,21 +254,23 @@ class FlatlandRemoteClient(object):
 if __name__ == "__main__":
     remote_client = FlatlandRemoteClient()
 
+
     def my_controller(obs, _env):
         _action = {}
         for _idx, _ in enumerate(_env.agents):
             _action[_idx] = np.random.randint(0, 5)
         return _action
-    
+
+
     my_observation_builder = TreeObsForRailEnv(max_depth=3,
-                                predictor=ShortestPathPredictorForRailEnv())
+                                               predictor=ShortestPathPredictorForRailEnv())
 
     episode = 0
     obs = True
-    while obs:        
+    while obs:
         obs = remote_client.env_create(
-                    obs_builder_object=my_observation_builder
-                    )
+            obs_builder_object=my_observation_builder
+        )
         if not obs:
             """
             The remote env returns False as the first obs
@@ -285,7 +291,5 @@ if __name__ == "__main__":
                 print("Reward : ", sum(list(all_rewards.values())))
                 break
 
-    print("Evaluation Complete...")       
+    print("Evaluation Complete...")
     print(remote_client.submit())
-
-
diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py
index 3ad0a97598c8beb66fc164eb45b670c87f3c96f9..8967b52d9d6ee70a7eb8af257ef6b4e25b531314 100644
--- a/flatland/evaluators/service.py
+++ b/flatland/evaluators/service.py
@@ -1,24 +1,26 @@
 #!/usr/bin/env python
 from __future__ import print_function
-import redis
-from flatland.envs.generators import rail_from_file
-from flatland.envs.rail_env import RailEnv
-from flatland.core.env_observation_builder import DummyObservationBuilder
-from flatland.evaluators import messages
-from flatland.evaluators import aicrowd_helpers
-from flatland.utils.rendertools import RenderTool
-import numpy as np
-import msgpack
-import msgpack_numpy as m
-import os
+
 import glob
+import os
+import random
 import shutil
 import time
 import traceback
+
 import crowdai_api
+import msgpack
+import msgpack_numpy as m
+import numpy as np
+import redis
 import timeout_decorator
-import random
 
+from flatland.core.env_observation_builder import DummyObservationBuilder
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import rail_from_file
+from flatland.evaluators import aicrowd_helpers
+from flatland.evaluators import messages
+from flatland.utils.rendertools import RenderTool
 
 use_signals_in_timeout = True
 if os.name == 'nt':
@@ -35,7 +37,7 @@ m.patch()
 ########################################################
 # CONSTANTS
 ########################################################
-PER_STEP_TIMEOUT = 10*60  # 5 minutes
+PER_STEP_TIMEOUT = 10 * 60  # 5 minutes
 
 
 class FlatlandRemoteEvaluationService:
@@ -59,17 +61,18 @@ class FlatlandRemoteEvaluationService:
     unpacked with `msgpack` (a patched version of msgpack which also supports
     numpy arrays).
     """
+
     def __init__(self,
-                test_env_folder="/tmp",
-                flatland_rl_service_id='FLATLAND_RL_SERVICE_ID',
-                remote_host='127.0.0.1',
-                remote_port=6379,
-                remote_db=0,
-                remote_password=None,
-                visualize=False,
-                video_generation_envs=[],
-                report=None,
-                verbose=False):
+                 test_env_folder="/tmp",
+                 flatland_rl_service_id='FLATLAND_RL_SERVICE_ID',
+                 remote_host='127.0.0.1',
+                 remote_port=6379,
+                 remote_db=0,
+                 remote_password=None,
+                 visualize=False,
+                 video_generation_envs=[],
+                 report=None,
+                 verbose=False):
 
         # Test Env folder Paths
         self.test_env_folder = test_env_folder
@@ -83,15 +86,15 @@ class FlatlandRemoteEvaluationService:
         # Logging and Reporting related vars
         self.verbose = verbose
         self.report = report
-        
+
         # Communication Protocol Related vars
         self.namespace = "flatland-rl"
         self.service_id = flatland_rl_service_id
         self.command_channel = "{}::{}::commands".format(
-                                    self.namespace, 
-                                    self.service_id
-                                )
-        
+            self.namespace,
+            self.service_id
+        )
+
         # Message Broker related vars
         self.remote_host = remote_host
         self.remote_port = remote_port
@@ -114,7 +117,7 @@ class FlatlandRemoteEvaluationService:
                 "normalized_reward": 0.0
             }
         }
-        
+
         # RailEnv specific variables
         self.env = False
         self.env_renderer = False
@@ -156,7 +159,7 @@ class FlatlandRemoteEvaluationService:
             ├── .......
             ├── .......
             └── Level_99.pkl 
-        """            
+        """
         env_paths = sorted(glob.glob(
             os.path.join(
                 self.test_env_folder,
@@ -179,16 +182,16 @@ class FlatlandRemoteEvaluationService:
         """
         if self.verbose or self.report:
             print("Attempting to connect to redis server at {}:{}/{}".format(
-                    self.remote_host, 
-                    self.remote_port, 
-                    self.remote_db))
+                self.remote_host,
+                self.remote_port,
+                self.remote_db))
 
         self.redis_pool = redis.ConnectionPool(
-                            host=self.remote_host, 
-                            port=self.remote_port, 
-                            db=self.remote_db, 
-                            password=self.remote_password
-                        )
+            host=self.remote_host,
+            port=self.remote_port,
+            db=self.remote_db,
+            password=self.remote_password
+        )
 
     def get_redis_connection(self):
         """
@@ -200,13 +203,13 @@ class FlatlandRemoteEvaluationService:
             redis_conn.ping()
         except Exception as e:
             raise Exception(
-                    "Unable to connect to redis server at {}:{} ."
-                    "Are you sure there is a redis-server running at the "
-                    "specified location ?".format(
-                        self.remote_host,
-                        self.remote_port
-                        )
-                    )
+                "Unable to connect to redis server at {}:{} ."
+                "Are you sure there is a redis-server running at the "
+                "specified location ?".format(
+                    self.remote_host,
+                    self.remote_port
+                )
+            )
         return redis_conn
 
     def _error_template(self, payload):
@@ -220,8 +223,8 @@ class FlatlandRemoteEvaluationService:
         return _response
 
     @timeout_decorator.timeout(
-                        PER_STEP_TIMEOUT,
-                        use_signals=use_signals_in_timeout)  # timeout for each command
+        PER_STEP_TIMEOUT,
+        use_signals=use_signals_in_timeout)  # timeout for each command
     def _get_next_command(self, _redis):
         """
         A low level wrapper for obtaining the next command from a 
@@ -231,7 +234,7 @@ class FlatlandRemoteEvaluationService:
         """
         command = _redis.brpop(self.command_channel)[1]
         return command
-    
+
     def get_next_command(self):
         """
         A helper function to obtain the next command, which transparently 
@@ -246,18 +249,18 @@ class FlatlandRemoteEvaluationService:
                 print("Command Service: ", command)
         except timeout_decorator.timeout_decorator.TimeoutError:
             raise Exception(
-                    "Timeout in step {} of simulation {}".format(
-                            self.current_step,
-                            self.simulation_count
-                            ))
+                "Timeout in step {} of simulation {}".format(
+                    self.current_step,
+                    self.simulation_count
+                ))
         command = msgpack.unpackb(
-                    command, 
-                    object_hook=m.decode, 
-                    encoding="utf8"
-                )
+            command,
+            object_hook=m.decode,
+            encoding="utf8"
+        )
         if self.verbose:
             print("Received Request : ", command)
-        
+
         return command
 
     def send_response(self, _command_response, command, suppress_logs=False):
@@ -266,15 +269,15 @@ class FlatlandRemoteEvaluationService:
 
         if self.verbose and not suppress_logs:
             print("Responding with : ", _command_response)
-        
+
         _redis.rpush(
-            command_response_channel, 
+            command_response_channel,
             msgpack.packb(
-                _command_response, 
-                default=m.encode, 
+                _command_response,
+                default=m.encode,
                 use_bin_type=True)
         )
-        
+
     def handle_ping(self, command):
         """
         Handles PING command from the client.
@@ -313,9 +316,9 @@ class FlatlandRemoteEvaluationService:
             )
             if self.visualize:
                 if self.env_renderer:
-                    del self.env_renderer     
+                    del self.env_renderer
                 self.env_renderer = RenderTool(self.env, gl="PILSVG", )
-            
+
             # Set max episode steps allowed
             self.env._max_episode_steps = \
                 int(1.5 * (self.env.width + self.env.height))
@@ -323,7 +326,7 @@ class FlatlandRemoteEvaluationService:
             if self.begin_simulation:
                 # If begin simulation has already been initialized 
                 # atleast once
-                self.simulation_times.append(time.time()-self.begin_simulation)
+                self.simulation_times.append(time.time() - self.begin_simulation)
             self.begin_simulation = time.time()
 
             self.simulation_rewards.append(0)
@@ -348,15 +351,15 @@ class FlatlandRemoteEvaluationService:
             _command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE
             _command_response['payload'] = {}
             _command_response['payload']['observation'] = False
-            _command_response['payload']['env_file_path'] = False            
+            _command_response['payload']['env_file_path'] = False
 
         self.send_response(_command_response, command)
         #####################################################################
         # Update evaluation state
         #####################################################################
         progress = np.clip(
-                    self.simulation_count * 1.0 / len(self.env_file_paths),
-                    0, 1)
+            self.simulation_count * 1.0 / len(self.env_file_paths),
+            0, 1)
         mean_reward = round(np.mean(self.simulation_rewards), 2)
         mean_normalized_reward = round(np.mean(self.simulation_rewards_normalized), 2)
         mean_percentage_complete = round(np.mean(self.simulation_percentage_complete), 3)
@@ -399,9 +402,9 @@ class FlatlandRemoteEvaluationService:
         """
         self.simulation_rewards_normalized[-1] += \
             cumulative_reward / (
-                        self.env._max_episode_steps + 
-                        self.env.get_num_agents()
-                    )
+                self.env._max_episode_steps +
+                self.env.get_num_agents()
+            )
 
         if done["__all__"]:
             # Compute percentage complete
@@ -412,14 +415,14 @@ class FlatlandRemoteEvaluationService:
                     complete += 1
             percentage_complete = complete * 1.0 / self.env.get_num_agents()
             self.simulation_percentage_complete[-1] = percentage_complete
-        
+
         # Record Frame
         if self.visualize:
             self.env_renderer.render_env(
-                                show=False, 
-                                show_observations=False, 
-                                show_predictions=False
-                                )
+                show=False,
+                show_observations=False,
+                show_predictions=False
+            )
             """
             Only save the frames for environments which are separately provided 
             in video_generation_indices param
@@ -427,10 +430,10 @@ class FlatlandRemoteEvaluationService:
             current_env_path = self.env_file_paths[self.simulation_count]
             if current_env_path in self.video_generation_envs:
                 self.env_renderer.gl.save_image(
-                        os.path.join(
-                            self.vizualization_folder_name,
-                            "flatland_frame_{:04d}.png".format(self.record_frame_step)
-                        ))
+                    os.path.join(
+                        self.vizualization_folder_name,
+                        "flatland_frame_{:04d}.png".format(self.record_frame_step)
+                    ))
                 self.record_frame_step += 1
 
         # Build and send response
@@ -453,7 +456,7 @@ class FlatlandRemoteEvaluationService:
         _payload = command['payload']
 
         # Register simulation time of the last episode
-        self.simulation_times.append(time.time()-self.begin_simulation)
+        self.simulation_times.append(time.time() - self.begin_simulation)
 
         if len(self.simulation_rewards) != len(self.env_file_paths):
             raise Exception(
@@ -461,7 +464,7 @@ class FlatlandRemoteEvaluationService:
                 to operate on all the test environments.
                 """
             )
-        
+
         mean_reward = round(np.mean(self.simulation_rewards), 2)
         mean_normalized_reward = round(np.mean(self.simulation_rewards_normalized), 2)
         mean_percentage_complete = round(np.mean(self.simulation_percentage_complete), 3)
@@ -473,7 +476,7 @@ class FlatlandRemoteEvaluationService:
             # install it by : 
             #
             # conda install -c conda-forge x264 ffmpeg
-            
+
             print("Generating Video from thumbnails...")
             video_output_path, video_thumb_output_path = \
                 aicrowd_helpers.generate_movie_from_frames(
@@ -518,14 +521,14 @@ class FlatlandRemoteEvaluationService:
         self.evaluation_state["score"]["score_secondary"] = mean_reward
         self.evaluation_state["meta"]["normalized_reward"] = mean_normalized_reward
         self.handle_aicrowd_success_event(self.evaluation_state)
-        print("#"*100)
+        print("#" * 100)
         print("EVALUATION COMPLETE !!")
-        print("#"*100)
+        print("#" * 100)
         print("# Mean Reward : {}".format(mean_reward))
         print("# Mean Normalized Reward : {}".format(mean_normalized_reward))
         print("# Mean Percentage Complete : {}".format(mean_percentage_complete))
-        print("#"*100)
-        print("#"*100)
+        print("#" * 100)
+        print("#" * 100)
 
     def report_error(self, error_message, command_response_channel):
         """
@@ -536,16 +539,16 @@ class FlatlandRemoteEvaluationService:
         _command_response['type'] = messages.FLATLAND_RL.ERROR
         _command_response['payload'] = error_message
         _redis.rpush(
-            command_response_channel, 
+            command_response_channel,
             msgpack.packb(
-                _command_response, 
-                default=m.encode, 
+                _command_response,
+                default=m.encode,
                 use_bin_type=True)
-            )
+        )
         self.evaluation_state["state"] = "ERROR"
         self.evaluation_state["error"] = error_message
         self.handle_aicrowd_error_event(self.evaluation_state)
-    
+
     def handle_aicrowd_info_event(self, payload):
         self.oracle_events.register_event(
             event_type=self.oracle_events.CROWDAI_EVENT_INFO,
@@ -577,17 +580,17 @@ class FlatlandRemoteEvaluationService:
                 print("Self.Reward : ", self.reward)
                 print("Current Simulation : ", self.simulation_count)
                 if self.env_file_paths and \
-                        self.simulation_count < len(self.env_file_paths):
+                    self.simulation_count < len(self.env_file_paths):
                     print("Current Env Path : ",
-                        self.env_file_paths[self.simulation_count])
+                          self.env_file_paths[self.simulation_count])
 
-            try:                
+            try:
                 if command['type'] == messages.FLATLAND_RL.PING:
                     """
                         INITIAL HANDSHAKE : Respond with PONG
                     """
                     self.handle_ping(command)
-                
+
                 elif command['type'] == messages.FLATLAND_RL.ENV_CREATE:
                     """
                         ENV_CREATE
@@ -612,8 +615,8 @@ class FlatlandRemoteEvaluationService:
                     self.handle_env_submit(command)
                 else:
                     _error = self._error_template(
-                                    "UNKNOWN_REQUEST:{}".format(
-                                        str(command)))
+                        "UNKNOWN_REQUEST:{}".format(
+                            str(command)))
                     if self.verbose:
                         print("Responding with : ", _error)
                     self.report_error(
@@ -631,10 +634,11 @@ class FlatlandRemoteEvaluationService:
 
 if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser(description='Submit the result to AIcrowd')
-    parser.add_argument('--service_id', 
-                        dest='service_id', 
-                        default='FLATLAND_RL_SERVICE_ID', 
+    parser.add_argument('--service_id',
+                        dest='service_id',
+                        default='FLATLAND_RL_SERVICE_ID',
                         required=False)
     parser.add_argument('--test_folder',
                         dest='test_folder',
@@ -642,16 +646,16 @@ if __name__ == "__main__":
                         help="Folder containing the files for the test envs",
                         required=False)
     args = parser.parse_args()
-    
+
     test_folder = args.test_folder
 
     grader = FlatlandRemoteEvaluationService(
-                test_env_folder=test_folder,
-                flatland_rl_service_id=args.service_id,
-                verbose=True,
-                visualize=True,
-                video_generation_envs=["Test_0/Level_1.pkl"]
-                )
+        test_env_folder=test_folder,
+        flatland_rl_service_id=args.service_id,
+        verbose=True,
+        visualize=True,
+        video_generation_envs=["Test_0/Level_1.pkl"]
+    )
     result = grader.run()
     if result['type'] == messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE:
         cumulative_results = result['payload']
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index 69be59ae2a957f6a2aaa948d9830472d7824516a..af1aad222919b00b716dd9da0f3be9534d54e411 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -11,9 +11,9 @@ from numpy import array
 import flatland.utils.rendertools as rt
 from flatland.core.grid.grid4_utils import mirror
 from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic
-from flatland.envs.generators import complex_rail_generator, empty_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv, random_rail_generator
+from flatland.envs.rail_generators import complex_rail_generator, empty_rail_generator
 
 
 class EditorMVC(object):
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index f86f301cdda4676855ad7a52623822f3053d6ea1..92a0f84f35fa942b03236c6add6e722475a2d842 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -4,7 +4,7 @@ import time
 import tkinter as tk
 
 import numpy as np
-from PIL import Image, ImageDraw, ImageTk  # , ImageFont
+from PIL import Image, ImageDraw, ImageTk, ImageFont
 from numpy import array
 from pkg_resources import resource_string as resource_bytes
 
@@ -41,7 +41,7 @@ class PILGL(GraphicsLayer):
     SELECTED_AGENT_LAYER = 4
     SELECTED_TARGET_LAYER = 5
 
-    def __init__(self, width, height, jupyter=False):
+    def __init__(self, width, height, jupyter=False, screen_width=800, screen_height=600):
         self.yxBase = (0, 0)
         self.linewidth = 4
         self.n_agent_colors = 1  # overridden in loadAgent
@@ -52,13 +52,13 @@ class PILGL(GraphicsLayer):
         self.background_grid = np.zeros(shape=(self.width, self.height))
 
         if jupyter is False:
-            # NOTE: Currently removed the dependency on 
-            #       screeninfo. We have to find an alternate 
+            # NOTE: Currently removed the dependency on
+            #       screeninfo. We have to find an alternate
             #       way to compute the screen width and height
-            #       In the meantime, we are harcoding the 800x600 
+            #       In the meantime, we are harcoding the 800x600
             #       assumption
-            self.screen_width = 800
-            self.screen_height = 600
+            self.screen_width = screen_width
+            self.screen_height = screen_height
             w = (self.screen_width - self.width - 10) / (self.width + 1 + self.linewidth)
             h = (self.screen_height - self.height - 10) / (self.height + 1 + self.linewidth)
             self.nPixCell = int(max(1, np.ceil(min(w, h))))
@@ -90,6 +90,8 @@ class PILGL(GraphicsLayer):
         self.old_background_image = (None, None, None)
         self.create_layers()
 
+        self.font = ImageFont.load_default()
+
     def build_background_map(self, dTargets):
         x = self.old_background_image
         rebuild = False
@@ -114,7 +116,7 @@ class PILGL(GraphicsLayer):
                     for rc in dTargets:
                         r = rc[1]
                         c = rc[0]
-                        d = int(np.floor(np.sqrt((x - r) ** 2 + (y - c) ** 2)))
+                        d = int(np.floor(np.sqrt((x - r) ** 2 + (y - c) ** 2)) / 0.5)
                         distance = min(d, distance)
                     self.background_grid[x][y] = distance
 
@@ -167,8 +169,14 @@ class PILGL(GraphicsLayer):
         # quit but not destroy!
         self.__class__.window.quit()
 
-    def text(self, *args, **kwargs):
-        pass
+    def text(self, xPx, yPx, strText, layer=RAIL_LAYER):
+        xyPixLeftTop = (xPx, yPx)
+        self.draws[layer].text(xyPixLeftTop, strText, font=self.font, fill=(0, 0, 0, 255))
+
+    def text_rowcol(self, rcTopLeft, strText, layer=AGENT_LAYER):
+        print("Text:", "rc:", rcTopLeft, "text:", strText, "layer:", layer)
+        xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]])
+        self.text(*xyPixLeftTop, strText, layer)
 
     def prettify(self, *args, **kwargs):
         pass
@@ -263,9 +271,9 @@ class PILGL(GraphicsLayer):
 
 
 class PILSVG(PILGL):
-    def __init__(self, width, height, jupyter=False):
+    def __init__(self, width, height, jupyter=False, screen_width=800, screen_height=600):
         oSuper = super()
-        oSuper.__init__(width, height, jupyter)
+        oSuper.__init__(width, height, jupyter, screen_width, screen_height)
 
         self.lwAgents = []
         self.agents_prev = []
@@ -444,7 +452,7 @@ class PILSVG(PILGL):
 
         for transition, file in file_directory.items():
 
-            # Translate the ascii transition description in the format  "NE WS" to the 
+            # Translate the ascii transition description in the format  "NE WS" to the
             # binary list of transitions as per RailEnv - NESW (in) x NESW (out)
             transition_16_bit = ["0"] * 16
             for sTran in transition.split(" "):
@@ -492,13 +500,17 @@ class PILSVG(PILGL):
                                           False)[0]
         self.draw_image_row_col(colored_rail, (row, col), layer=PILGL.PREDICTION_PATH_LAYER)
 
-    def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None):
+    def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None,
+                    show_debug=True):
+
         if binary_trans in self.pil_rail:
             pil_track = self.pil_rail[binary_trans]
             if target is not None:
                 target_img = self.station_colors[target % len(self.station_colors)]
                 target_img = Image.alpha_composite(pil_track, target_img)
                 self.draw_image_row_col(target_img, (row, col), layer=PILGL.TARGET_LAYER)
+                if show_debug:
+                    self.text_rowcol((row + 0.8, col + 0.0), strText=str(target), layer=PILGL.TARGET_LAYER)
 
             if binary_trans == 0:
                 if self.background_grid[col][row] <= 4:
@@ -579,7 +591,7 @@ class PILSVG(PILGL):
                 for color_idx, pil_zug_3 in enumerate(pils):
                     self.pil_zug[(in_direction_2, out_direction_2, color_idx)] = pils[color_idx]
 
-    def set_agent_at(self, agent_idx, row, col, in_direction, out_direction, is_selected):
+    def set_agent_at(self, agent_idx, row, col, in_direction, out_direction, is_selected, show_debug=False):
         delta_dir = (out_direction - in_direction) % 4
         color_idx = agent_idx % self.n_agent_colors
         # when flipping direction at a dead end, use the "out_direction" direction.
@@ -593,6 +605,10 @@ class PILSVG(PILGL):
             self.clear_layer(PILGL.SELECTED_AGENT_LAYER, 0)
             self.draw_image_row_col(bg_svg, (row, col), layer=PILGL.SELECTED_AGENT_LAYER)
 
+        if show_debug:
+            print("Call text:")
+            self.text_rowcol((row + 0.2, col + 0.2,), str(agent_idx))
+
     def set_cell_occupied(self, agent_idx, row, col):
         occupied_im = self.cell_occupied[agent_idx % len(self.cell_occupied)]
         self.draw_image_row_col(occupied_im, (row, col), 1)
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 8974126ad69e35edbd621ba634727120720c9506..802b361b623cdaea08271f5748ac86194056bdf2 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -39,7 +39,10 @@ class RenderTool(object):
     theta = np.linspace(0, np.pi / 2, 5)
     arc = array([np.cos(theta), np.sin(theta)]).T  # from [1,0] to [0,1]
 
-    def __init__(self, env, gl="PILSVG", jupyter=False, agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND):
+    def __init__(self, env, gl="PILSVG", jupyter=False,
+                 agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
+                 show_debug=False, screen_width=800, screen_height=600):
+
         self.env = env
         self.frame_nr = 0
         self.start_time = time.time()
@@ -48,14 +51,15 @@ class RenderTool(object):
         self.agent_render_variant = agent_render_variant
 
         if gl == "PIL":
-            self.gl = PILGL(env.width, env.height, jupyter)
+            self.gl = PILGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
         elif gl == "PILSVG":
-            self.gl = PILSVG(env.width, env.height, jupyter)
+            self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
         else:
             print("[", gl, "] not found, switch to PILSVG")
-            self.gl = PILSVG(env.width, env.height, jupyter)
+            self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
 
         self.new_rail = True
+        self.show_debug = show_debug
         self.update_background()
 
     def reset(self):
@@ -282,7 +286,7 @@ class RenderTool(object):
         if len(observation_dict) < 1:
             warnings.warn(
                 "Predictor did not provide any predicted cells to render. \
-                Observaiton builder needs to populate: env.dev_obs_dict")
+                Observation builder needs to populate: env.dev_obs_dict")
         else:
             for agent in agent_handles:
                 color = self.gl.get_agent_color(agent)
@@ -525,7 +529,7 @@ class RenderTool(object):
                         is_selected = False
 
                     self.gl.set_rail_at(r, c, transitions, target=target, is_selected=is_selected,
-                                        rail_grid=env.rail.grid)
+                                        rail_grid=env.rail.grid, show_debug=self.show_debug)
 
             self.gl.build_background_map(targets)
 
@@ -550,7 +554,8 @@ class RenderTool(object):
                 # set_agent_at uses the agent index for the color
                 if self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:
                     self.gl.set_cell_occupied(agent_idx, *(agent.position))
-                self.gl.set_agent_at(agent_idx, *position, old_direction, direction, selected_agent == agent_idx)
+                self.gl.set_agent_at(agent_idx, *position, old_direction, direction,
+                                     selected_agent == agent_idx, show_debug=self.show_debug)
             else:
                 position = agent.position
                 direction = agent.direction
@@ -562,7 +567,7 @@ class RenderTool(object):
 
                         # set_agent_at uses the agent index for the color
                         self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,
-                                             selected_agent == agent_idx)
+                                             selected_agent == agent_idx, show_debug=self.show_debug)
 
                 # set_agent_at uses the agent index for the color
                 if self.agent_render_variant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX:
diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py
index 28978ca3e5b958a51c578f6be5e8c87b77baaa97..67bd93dd35c8f53ef3cdef23dbae0f0d785b9a64 100644
--- a/flatland/utils/simple_rail.py
+++ b/flatland/utils/simple_rail.py
@@ -2,11 +2,124 @@ from typing import Tuple
 
 import numpy as np
 
-from flatland.core.grid.grid4 import Grid4Transitions
+from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 
 
 def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
+    # We instantiate a very simple rail network on a 7x10 grid:
+    # Note that that cells have invalid RailEnvTransitions!
+    #        |
+    #        |
+    #        |
+    # _ _ _ _\ _ _  _  _ _ _
+    #                /
+    #                |
+    #                |
+    #                |
+    transitions = RailEnvTransitions()
+    cells = transitions.transition_list
+    empty = cells[0]
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+    simple_switch_north_left = cells[2]
+    simple_switch_north_right = cells[10]
+    simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
+    simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270)
+    rail_map = np.array(
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
+        [[dead_end_from_east] + [horizontal_straight] * 2 +
+         [simple_switch_east_west_north] +
+         [horizontal_straight] * 2 + [simple_switch_east_west_south] +
+         [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    return rail, rail_map
+
+
+def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
+    # We instantiate a very simple rail network on a 7x10 grid:
+    #        |
+    #        |
+    #        |
+    # _ _ _ _\ _ _  _  _ _ _
+    #               \
+    #                |
+    #                |
+    #                |
+    transitions = RailEnvTransitions()
+    cells = transitions.transition_list
+    empty = cells[0]
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+    simple_switch_north_right = cells[10]
+    simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
+    simple_switch_west_east_south = transitions.rotate_transition(simple_switch_north_right, 90)
+    rail_map = np.array(
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
+        [[dead_end_from_east] + [horizontal_straight] * 2 +
+         [simple_switch_east_west_north] +
+         [horizontal_straight] * 2 + [simple_switch_west_east_south] +
+         [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    return rail, rail_map
+
+def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]:
+    # We instantiate a very simple rail network on a 7x10 grid:
+    # Note that that cells have invalid RailEnvTransitions!
+    #        |
+    #        |
+    #        |
+    # _ _ _  _ _ _  _  _ _ _
+    #                /
+    #                |
+    #                |
+    #                |
+    transitions = RailEnvTransitions()
+    cells = transitions.transition_list
+    empty = cells[0]
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+    simple_switch_north_left = cells[2]
+    simple_switch_north_right = cells[10]
+    simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
+    simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270)
+    rail_map = np.array(
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6]  +
+        [[empty] * 3 + [dead_end_from_north] + [empty] * 6]  +
+        [[dead_end_from_east] + [horizontal_straight]  * 5 + [simple_switch_east_west_south] +
+         [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    return rail, rail_map
+
+
+def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]:
     # We instantiate a very simple rail network on a 7x10 grid:
     #        |
     #        |
@@ -16,15 +129,9 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
     #                |
     #                |
     #                |
-    cells = [int('0000000000000000', 2),  # empty cell - Case 0
-             int('1000000000100000', 2),  # Case 1 - straight
-             int('1001001000100000', 2),  # Case 2 - simple switch
-             int('1000010000100001', 2),  # Case 3 - diamond drossing
-             int('1001011000100001', 2),  # Case 4 - single slip switch
-             int('1100110000110011', 2),  # Case 5 - double slip switch
-             int('0101001000000010', 2),  # Case 6 - symmetrical switch
-             int('0010000000000000', 2)]  # Case 7 - dead end
-    transitions = Grid4Transitions([])
+    transitions = RailEnvTransitions()
+    cells = transitions.transition_list
+
     empty = cells[0]
     dead_end_from_south = cells[7]
     dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
diff --git a/notebooks/simple_example1_env_from_tuple.ipynb b/notebooks/simple_example1_env_from_tuple.ipynb
index 3fd55bc8fafabdd57eb43e012c5d98b37c73d496..0fcfe26325e5778cbbfdf66591346972edb2406d 100644
--- a/notebooks/simple_example1_env_from_tuple.ipynb
+++ b/notebooks/simple_example1_env_from_tuple.ipynb
@@ -14,7 +14,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from flatland.envs.generators import rail_from_manual_specifications_generator\n",
+    "from flatland.envs.rail_generators import rail_from_manual_specifications_generator\n",
     "from flatland.envs.observations import TreeObsForRailEnv\n",
     "from flatland.envs.rail_env import RailEnv\n",
     "from flatland.utils.rendertools import RenderTool\n",
diff --git a/notebooks/simple_example2_generate_random_rail.ipynb b/notebooks/simple_example2_generate_random_rail.ipynb
index b9d4a96c02d4cb63392654025c6857c8b6764d1a..19b854ee15d8dd1e19361f58a552eba617b19b67 100644
--- a/notebooks/simple_example2_generate_random_rail.ipynb
+++ b/notebooks/simple_example2_generate_random_rail.ipynb
@@ -15,7 +15,7 @@
    "source": [
     "import random\n",
     "import numpy as np\n",
-    "from flatland.envs.generators import random_rail_generator\n",
+    "from flatland.envs.rail_generators import random_rail_generator\n",
     "from flatland.envs.observations import TreeObsForRailEnv\n",
     "from flatland.envs.rail_env import RailEnv\n",
     "from flatland.utils.rendertools import RenderTool\n",
diff --git a/notebooks/simple_example_3_manual_control.ipynb b/notebooks/simple_example_3_manual_control.ipynb
index 50f228055b320c15e0411c2b086254a1f4d4ceef..cb2b377765f375e8d77d8374fdcc3bb67ce06444 100644
--- a/notebooks/simple_example_3_manual_control.ipynb
+++ b/notebooks/simple_example_3_manual_control.ipynb
@@ -40,7 +40,7 @@
     "import random\n",
     "import numpy as np\n",
     "import time\n",
-    "from flatland.envs.generators import random_rail_generator\n",
+    "from flatland.envs.rail_generators import random_rail_generator\n",
     "from flatland.envs.observations import TreeObsForRailEnv\n",
     "from flatland.envs.rail_env import RailEnv\n",
     "from flatland.utils.rendertools import RenderTool"
diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py
index 12e0c092a37a475ab6e7dde21c665778e06f5e59..e5e89f76428bb881d0f72aa60aada97ab02167a5 100644
--- a/tests/test_distance_map.py
+++ b/tests/test_distance_map.py
@@ -1,25 +1,19 @@
 import numpy as np
 
-from flatland.core.grid.grid4 import Grid4Transitions
+from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.generators import rail_from_grid_transition_map
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator
 
 
 def test_walker():
     # _ _ _
 
-    cells = [int('0000000000000000', 2),  # empty cell - Case 0
-             int('1000000000100000', 2),  # Case 1 - straight
-             int('1001001000100000', 2),  # Case 2 - simple switch
-             int('1000010000100001', 2),  # Case 3 - diamond drossing
-             int('1001011000100001', 2),  # Case 4 - single slip switch
-             int('1100110000110011', 2),  # Case 5 - double slip switch
-             int('0101001000000010', 2),  # Case 6 - symmetrical switch
-             int('0010000000000000', 2)]  # Case 7 - dead end
-    transitions = Grid4Transitions([])
+    transitions = RailEnvTransitions()
+    cells = transitions.transition_list
     dead_end_from_south = cells[7]
     dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
     dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
@@ -34,6 +28,7 @@ def test_walker():
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2,
                                                        predictor=ShortestPathPredictorForRailEnv(max_depth=10)),
diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py
index a414231619cfa924c2d33776f9f140cf88280517..8812c847e61d81f6614f37d26489b4c17ea7fd14 100644
--- a/tests/test_flatland_core_transition_map.py
+++ b/tests/test_flatland_core_transition_map.py
@@ -1,6 +1,13 @@
 from flatland.core.grid.grid4 import Grid4Transitions, Grid4TransitionsEnum
 from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum
 from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator
+from flatland.utils.rendertools import RenderTool
+from flatland.utils.simple_rail import make_simple_rail, make_simple_rail_unconnected
 
 
 def test_grid4_get_transitions():
@@ -43,4 +50,111 @@ def test_grid8_set_transitions():
     grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0)
     assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0)
 
-# TODO GridTransitionMap
+
+def check_path(env, rail, position, direction, target, expected, rendering=False):
+    agent = env.agents_static[0]
+    agent.position = position  # south dead-end
+    agent.direction = direction  # north
+    agent.target = target  # east dead-end
+    agent.moving = True
+    # reset to set agents from agents_static
+    # env.reset(False, False)
+    if rendering:
+        renderer = RenderTool(env, gl="PILSVG")
+        renderer.render_env(show=True, show_observations=False)
+        input("Continue?")
+    assert rail.check_path_exists(agent.position, agent.direction, agent.target) == expected
+
+
+def test_path_exists(rendering=False):
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+
+    # reset to initialize agents_static
+    env.reset()
+
+    check_path(
+        env,
+        rail,
+        (5, 6),  # north of south dead-end
+        0,  # north
+        (3, 9),  # east dead-end
+        True
+    )
+
+    check_path(
+        env,
+        rail,
+        (6, 6),  # south dead-end
+        2,  # south
+        (3, 9),  # east dead-end
+        True
+    )
+
+    check_path(
+        env,
+        rail,
+        (3, 0),  # east dead-end
+        3,  # west
+        (0, 3),  # north dead-end
+        True
+    )
+    check_path(
+        env,
+        rail,
+        (5, 6),  # east dead-end
+        0,  # west
+        (1, 3),  # north dead-end
+        True)
+
+    check_path(
+        env,
+        rail,
+        (1,3),  # east dead-end
+        2,  # south
+        (3,3),  # north dead-end
+        True
+    )
+
+    check_path(
+        env,
+        rail,
+        (1,3),  # east dead-end
+        0,  # north
+        (3,3),  # north dead-end
+        True
+    )
+
+
+def test_path_not_exists(rendering=False):
+    rail, rail_map = make_simple_rail_unconnected()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+
+    # reset to initialize agents_static
+    env.reset()
+
+    check_path(
+        env,
+        rail,
+        (5, 6),  # south dead-end
+        0,  # north
+        (0, 3),  # north dead-end
+        False
+    )
+
+    if rendering:
+        renderer = RenderTool(env, gl="PILSVG")
+        renderer.render_env(show=True, show_observations=False)
+        input("Continue?")
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index 574705c49501415f37149bd4b3d870665bf06e60..c96e8db00fe721f42667aed4833d034a47f19156 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -5,10 +5,11 @@ import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.envs.agent_utils import EnvAgent
-from flatland.envs.generators import rail_from_grid_transition_map
 from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from flatland.utils.simple_rail import make_simple_rail
 
@@ -21,6 +22,7 @@ def test_global_obs():
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=1,
                   obs_builder_object=GlobalObsForRailEnv())
 
@@ -90,6 +92,7 @@ def test_reward_function_conflict(rendering=False):
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
@@ -168,6 +171,7 @@ def test_reward_function_waiting(rendering=False):
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index 7acd58ed0337745f645db6dcc24a70ecb0b64305..09f7e5e67a15c55b5070ac8679e43ecc9a14b9da 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -5,22 +5,24 @@ import pprint
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.generators import rail_from_grid_transition_map
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator
 from flatland.utils.rendertools import RenderTool
-from flatland.utils.simple_rail import make_simple_rail
+from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
 
 """Test predictions for `flatland` package."""
 
 
 def test_dummy_predictor(rendering=False):
-    rail, rail_map = make_simple_rail()
+    rail, rail_map = make_simple_rail2()
 
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
                   )
@@ -89,7 +91,7 @@ def test_dummy_predictor(rendering=False):
     expected_actions = np.array([[0.],
                                  [2.],
                                  [2.],
-                                 [1.],
+                                 [2.],
                                  [2.],
                                  [2.],
                                  [2.],
@@ -111,6 +113,7 @@ def test_shortest_path_predictor(rendering=False):
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
@@ -226,10 +229,11 @@ def test_shortest_path_predictor(rendering=False):
 
 
 def test_shortest_path_predictor_conflicts(rendering=False):
-    rail, rail_map = make_simple_rail()
+    rail, rail_map = make_invalid_simple_rail()
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index 71dc87ceddde986be763491d28dd2b70673632f4..d5dc3ac7af4be6ebd8c5cbeaf705bb710d36d138 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -2,15 +2,15 @@
 # -*- coding: utf-8 -*-
 import numpy as np
 
-from flatland.core.grid.grid4 import Grid4Transitions
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.agent_utils import EnvAgentStatic
-from flatland.envs.generators import complex_rail_generator
-from flatland.envs.generators import rail_from_grid_transition_map
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator
 
 """Tests for `flatland` package."""
 
@@ -27,6 +27,7 @@ def test_load_env():
 def test_save_load():
     env = RailEnv(width=10, height=10,
                   rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0),
+                  schedule_generator=complex_schedule_generator(),
                   number_of_agents=2)
     env.reset()
     agent_1_pos = env.agents_static[0].position
@@ -49,15 +50,6 @@ def test_save_load():
 
 
 def test_rail_environment_single_agent():
-    cells = [int('0000000000000000', 2),  # empty cell - Case 0
-             int('1000000000100000', 2),  # Case 1 - straight
-             int('1001001000100000', 2),  # Case 2 - simple switch
-             int('1000010000100001', 2),  # Case 3 - diamond drossing
-             int('1001011000100001', 2),  # Case 4 - single slip switch
-             int('1100110000110011', 2),  # Case 5 - double slip switch
-             int('0101001000000010', 2),  # Case 6 - symmetrical switch
-             int('0010000000000000', 2)]  # Case 7 - dead end
-
     # We instantiate the following map on a 3x3 grid
     #  _  _
     # / \/ \
@@ -65,6 +57,7 @@ def test_rail_environment_single_agent():
     # \_/\_/
 
     transitions = RailEnvTransitions()
+    cells = transitions.transition_list
     vertical_line = cells[1]
     south_symmetrical_switch = cells[6]
     north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
@@ -86,6 +79,7 @@ def test_rail_environment_single_agent():
     rail_env = RailEnv(width=3,
                        height=3,
                        rail_generator=rail_from_grid_transition_map(rail),
+                       schedule_generator=random_schedule_generator(),
                        number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
 
@@ -139,7 +133,7 @@ test_rail_environment_single_agent()
 
 
 def test_dead_end():
-    transitions = Grid4Transitions([])
+    transitions = RailEnvTransitions()
 
     straight_vertical = int('1000000000100000', 2)  # Case 1 - straight
     straight_horizontal = transitions.rotate_transition(straight_vertical,
@@ -165,6 +159,7 @@ def test_dead_end():
     rail_env = RailEnv(width=rail_map.shape[1],
                        height=rail_map.shape[0],
                        rail_generator=rail_from_grid_transition_map(rail),
+                       schedule_generator=random_schedule_generator(),
                        number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
 
@@ -209,6 +204,7 @@ def test_dead_end():
     rail_env = RailEnv(width=rail_map.shape[1],
                        height=rail_map.shape[0],
                        rail_generator=rail_from_grid_transition_map(rail),
+                       schedule_generator=random_schedule_generator(),
                        number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
 
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..db7cac61f4cf3bec4a330694c1864ef7d82bd076
--- /dev/null
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -0,0 +1,26 @@
+from flatland.envs.observations import GlobalObsForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_schedule_generator
+from flatland.utils.rendertools import RenderTool
+
+
+def test_sparse_rail_generator():
+    env = RailEnv(width=50,
+                  height=50,
+                  rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
+                                                       num_intersections=10,  # Number of interesections in map
+                                                       num_trainstations=50,  # Number of possible start/targets on map
+                                                       min_node_dist=6,  # Minimal distance of nodes
+                                                       node_radius=3,  # Proximity of stations to city center
+                                                       num_neighb=3,  # Number of connections to other cities
+                                                       seed=5,  # Random seed
+                                                       realistic_mode=False  # Ordered distribution of nodes
+                                                       ),
+                  schedule_generator=sparse_schedule_generator(),
+                  number_of_agents=10,
+                  obs_builder_object=GlobalObsForRailEnv())
+    # reset to initialize agents_static
+    env_renderer = RenderTool(env, gl="PILSVG", )
+    env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
+    env_renderer.gl.save_image("./sparse_generator_false.png")
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 67dcd25c0769e542fd9a03502c2a8c1b29333b2b..eaf782df3255ecfc6ebaa7078935f485497ed359 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -1,8 +1,9 @@
 import numpy as np
 
-from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
 
 
 class SingleAgentNavigationObs(TreeObsForRailEnv):
@@ -62,6 +63,7 @@ def test_malfunction_process():
                   height=20,
                   rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
                                                         seed=0),
+                  schedule_generator=complex_schedule_generator(),
                   number_of_agents=2,
                   obs_builder_object=SingleAgentNavigationObs(),
                   stochastic_data=stochastic_data)
diff --git a/tests/test_flatland_utils_rendertools.py b/tests/test_flatland_utils_rendertools.py
index ac5d7f4132b5c224206af3602b6c1341fe026d8b..8248c675995fc5c906e82d8650a5b619e7b038f2 100644
--- a/tests/test_flatland_utils_rendertools.py
+++ b/tests/test_flatland_utils_rendertools.py
@@ -11,9 +11,9 @@ from importlib_resources import path
 
 import flatland.utils.rendertools as rt
 import images.test
-from flatland.envs.generators import empty_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import empty_rail_generator
 
 
 def checkFrozenImage(oRT, sFileImage, resave=False):
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 8b5468716fa12d83a0546a9b6ff34f2488beace5..8de36c81e4a13c0b7e7e5e556ad79234503ad31a 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -1,10 +1,12 @@
 import numpy as np
 
-from flatland.envs.generators import complex_rail_generator
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
 
 np.random.seed(1)
 
+
 # Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
 # Training on simple small tasks is the best way to get familiar with the environment
 #
@@ -46,6 +48,7 @@ def test_multi_speed_init():
                   height=50,
                   rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
                                                         seed=0),
+                  schedule_generator=complex_schedule_generator(),
                   number_of_agents=5)
     # Initialize the agent with the parameters corresponding to the environment and observation_builder
     agent = RandomAgent(218, 4)
@@ -66,7 +69,7 @@ def test_multi_speed_init():
     # Run episode
     for step in range(100):
 
-        # Chose an action for each agent in the environment
+        # Choose an action for each agent in the environment
         for a in range(env.get_num_agents()):
             action = agent.act(0)
             action_dict.update({a: action})
@@ -75,12 +78,11 @@ def test_multi_speed_init():
             assert old_pos[a] == env.agents[a].position
 
         # Environment step which returns the observations for all agents, their corresponding
-        # reward and whether their are done
+        # reward and whether they are done
         _, _, _, _ = env.step(action_dict)
 
         # Update old position whenever an agent was allowed to move
         for i_agent in range(env.get_num_agents()):
             if (step + 1) % (i_agent + 1) == 0:
-                print(step, i_agent, env.agents[a].position)
-
+                print(step, i_agent, env.agents[i_agent].position)
                 old_pos[i_agent] = env.agents[i_agent].position
diff --git a/tests/test_speed_classes.py b/tests/test_speed_classes.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff5ee56a308ce19559d079b716bde90ad65baf11
--- /dev/null
+++ b/tests/test_speed_classes.py
@@ -0,0 +1,36 @@
+"""Test speed initialization by a map of speeds and their corresponding ratios."""
+import numpy as np
+
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import speed_initialization_helper, complex_schedule_generator
+
+
+def test_speed_initialization_helper():
+    np.random.seed(1)
+    speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.3}
+    actual_speeds = speed_initialization_helper(10, speed_ratio_map)
+
+    # seed makes speed_initialization_helper deterministic -> check generated speeds.
+    assert actual_speeds == [2, 3, 1, 2, 1, 1, 1, 2, 2, 2]
+
+
+def test_rail_env_speed_intializer():
+    speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2}
+
+    env = RailEnv(width=50,
+                  height=50,
+                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
+                                                        seed=0),
+                  schedule_generator=complex_schedule_generator(),
+                  number_of_agents=10)
+    env.reset()
+    actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents))
+
+    expected_speed_set = set(speed_ratio_map.keys())
+
+    # check that the number of speeds generated is correct
+    assert len(actual_speeds) == env.get_num_agents()
+
+    # check that only the speeds defined are generated
+    assert all({(actual_speed in expected_speed_set) for actual_speed in actual_speeds})
diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index f97b071e6b33c099efa5af36766e159e57716443..8b0480c887a53ade155c28aa6199db3d32f19603 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -3,11 +3,13 @@
 
 import numpy as np
 
-from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \
-    random_rail_generator, empty_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \
+    random_rail_generator, empty_rail_generator
+from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator, \
+    schedule_from_file
 from flatland.utils.simple_rail import make_simple_rail
 
 
@@ -58,7 +60,8 @@ def test_complex_rail_generator():
     env = RailEnv(width=x_dim,
                   height=y_dim,
                   number_of_agents=n_agents,
-                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
+                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
+                  schedule_generator=complex_schedule_generator()
                   )
     assert env.get_num_agents() == 2
     assert env.rail.grid.shape == (y_dim, x_dim)
@@ -69,7 +72,8 @@ def test_complex_rail_generator():
     env = RailEnv(width=x_dim,
                   height=y_dim,
                   number_of_agents=n_agents,
-                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
+                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
+                  schedule_generator=complex_schedule_generator()
                   )
     assert env.get_num_agents() == 0
     assert env.rail.grid.shape == (y_dim, x_dim)
@@ -82,7 +86,8 @@ def test_complex_rail_generator():
     env = RailEnv(width=x_dim,
                   height=y_dim,
                   number_of_agents=n_agents,
-                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
+                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
+                  schedule_generator=complex_schedule_generator()
                   )
     assert env.get_num_agents() == n_agents
     assert env.rail.grid.shape == (y_dim, x_dim)
@@ -94,6 +99,7 @@ def test_rail_from_grid_transition_map():
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=n_agents
                   )
     nr_rail_elements = np.count_nonzero(env.rail.grid)
@@ -118,6 +124,7 @@ def tests_rail_from_file():
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=3,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
@@ -130,6 +137,7 @@ def tests_rail_from_file():
     env = RailEnv(width=1,
                   height=1,
                   rail_generator=rail_from_file(file_name),
+                  schedule_generator=schedule_from_file(file_name),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
@@ -151,6 +159,7 @@ def tests_rail_from_file():
     env2 = RailEnv(width=rail_map.shape[1],
                    height=rail_map.shape[0],
                    rail_generator=rail_from_grid_transition_map(rail),
+                   schedule_generator=random_schedule_generator(),
                    number_of_agents=3,
                    obs_builder_object=GlobalObsForRailEnv(),
                    )
@@ -164,6 +173,7 @@ def tests_rail_from_file():
     env2 = RailEnv(width=1,
                    height=1,
                    rail_generator=rail_from_file(file_name_2),
+                   schedule_generator=schedule_from_file(file_name_2),
                    number_of_agents=1,
                    obs_builder_object=GlobalObsForRailEnv(),
                    )
@@ -180,6 +190,7 @@ def tests_rail_from_file():
     env3 = RailEnv(width=1,
                    height=1,
                    rail_generator=rail_from_file(file_name),
+                   schedule_generator=schedule_from_file(file_name),
                    number_of_agents=1,
                    obs_builder_object=GlobalObsForRailEnv(),
                    )
@@ -197,6 +208,7 @@ def tests_rail_from_file():
     env4 = RailEnv(width=1,
                    height=1,
                    rail_generator=rail_from_file(file_name_2),
+                   schedule_generator=schedule_from_file(file_name_2),
                    number_of_agents=1,
                    obs_builder_object=TreeObsForRailEnv(max_depth=2),
                    )