diff --git a/examples/complex_rail_benchmark.py b/examples/complex_rail_benchmark.py
index 44e4b534c2f2dfa63cab385d009c9afd92285f48..49e550b195ffb15e6554413069369378e80e5f82 100644
--- a/examples/complex_rail_benchmark.py
+++ b/examples/complex_rail_benchmark.py
@@ -3,8 +3,9 @@ import random
 
 import numpy as np
 
-from flatland.envs.generators import complex_rail_generator
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
 
 
 def run_benchmark():
@@ -15,6 +16,7 @@ def run_benchmark():
     # Example generate a random rail
     env = RailEnv(width=15, height=15,
                   rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
+                  schedule_generator=complex_schedule_generator(),
                   number_of_agents=5)
 
     n_trials = 20
diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py
index 723bb1102092c7d48bd938bbd60d0c5213ffecf6..8b1de6aa4e303469d30983d30333fbfda89c1d1e 100644
--- a/examples/custom_observation_example.py
+++ b/examples/custom_observation_example.py
@@ -5,10 +5,11 @@ import numpy as np
 
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid_utils import coordinate_to_position
-from flatland.envs.generators import random_rail_generator, complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import random_rail_generator, complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
 from flatland.utils.rendertools import RenderTool
 
 random.seed(100)
@@ -20,6 +21,7 @@ class SimpleObs(ObservationBuilder):
     Simplest observation builder. The object returns observation vectors with 5 identical components,
     all equal to the ID of the respective agent.
     """
+
     def __init__(self):
         self.observation_space = [5]
 
@@ -53,6 +55,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
     E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
     will be [1, 0, 0].
     """
+
     def __init__(self):
         super().__init__(max_depth=0)
         self.observation_space = [3]
@@ -90,6 +93,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
 env = RailEnv(width=7,
               height=7,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0),
+              schedule_generator=complex_schedule_generator(),
               number_of_agents=1,
               obs_builder_object=SingleAgentNavigationObs())
 
@@ -97,8 +101,8 @@ obs = env.reset()
 env_renderer = RenderTool(env, gl="PILSVG")
 env_renderer.render_env(show=True, frames=True, show_observations=True)
 for step in range(100):
-    action = np.argmax(obs[0])+1
-    obs, all_rewards, done, _ = env.step({0:action})
+    action = np.argmax(obs[0]) + 1
+    obs, all_rewards, done, _ = env.step({0: action})
     print("Rewards: ", all_rewards, "  [done=", done, "]")
     env_renderer.render_env(show=True, frames=True, show_observations=True)
     time.sleep(0.1)
@@ -200,6 +204,7 @@ CustomObsBuilder = ObservePredictions(CustomPredictor)
 env = RailEnv(width=10,
               height=10,
               rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+              schedule_generator=complex_schedule_generator(),
               number_of_agents=3,
               obs_builder_object=CustomObsBuilder)
 
diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py
index 515d6c1b0469b7fbd9bad8cd82a40db7766f6219..04da66904fda1a58847a4acc510d7fc4e4e86887 100644
--- a/examples/custom_railmap_example.py
+++ b/examples/custom_railmap_example.py
@@ -1,30 +1,41 @@
 import random
+from typing import Any
 
 import numpy as np
 
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct
+from flatland.envs.schedule_generators import ScheduleGenerator, ScheduleGeneratorProduct
 from flatland.utils.rendertools import RenderTool
 
 random.seed(100)
 np.random.seed(100)
 
 
-def custom_rail_generator():
-    def generator(width, height, num_agents=0, num_resets=0):
+def custom_rail_generator() -> RailGenerator:
+    def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct:
         rail_trans = RailEnvTransitions()
         grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
         rail_array = grid_map.grid
         rail_array.fill(0)
         new_tran = rail_trans.set_transition(1, 1, 1, 1)
         print(new_tran)
+        rail_array[0, 0] = new_tran
+        rail_array[0, 1] = new_tran
+        return grid_map, None
+
+    return generator
+
+
+def custom_schedule_generator() -> ScheduleGenerator:
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct:
         agents_positions = []
         agents_direction = []
         agents_target = []
-        rail_array[0, 0] = new_tran
-        rail_array[0, 1] = new_tran
-        return grid_map, agents_positions, agents_direction, agents_target
+        speeds = []
+        return agents_positions, agents_direction, agents_target, speeds
 
     return generator
 
diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py
index 2c0f814576caef84471d20c91dd92d23d4db02ac..50ea74b84ac9851e88e48bcd32b914e69bc7dd34 100644
--- a/examples/debugging_example_DELETE.py
+++ b/examples/debugging_example_DELETE.py
@@ -3,14 +3,16 @@ import time
 
 import numpy as np
 
-from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
 from flatland.utils.rendertools import RenderTool
 
 random.seed(1)
 np.random.seed(1)
 
+
 class SingleAgentNavigationObs(TreeObsForRailEnv):
     """
     We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
@@ -21,6 +23,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
     E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
     will be [1, 0, 0].
     """
+
     def __init__(self):
         super().__init__(max_depth=0)
         self.observation_space = [3]
@@ -58,6 +61,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
 env = RailEnv(width=14,
               height=14,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0),
+              schedule_generator=complex_schedule_generator(),
               number_of_agents=2,
               obs_builder_object=SingleAgentNavigationObs())
 
@@ -67,11 +71,11 @@ env_renderer.render_env(show=True, frames=True, show_observations=False)
 for step in range(100):
     actions = {}
     for i in range(len(obs)):
-        actions[i] = np.argmax(obs[i])+1
+        actions[i] = np.argmax(obs[i]) + 1
 
-    if step%5 == 0:
+    if step % 5 == 0:
         print("Agent halts")
-        actions[0] = 4 # Halt
+        actions[0] = 4  # Halt
 
     obs, all_rewards, done, _ = env.step(actions)
     if env.agents[0].malfunction_data['malfunction'] > 0:
@@ -82,4 +86,3 @@ for step in range(100):
     if done["__all__"]:
         break
 env_renderer.close_window()
-
diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 916e50b20b10a02c43c5b1da8bc0728930b8c535..71a185c765bcab831e7b104124a164bcf2398b14 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -1,9 +1,10 @@
 import numpy as np
+from flatland.envs.rail_generators import sparse_rail_generator
 
-from flatland.envs.generators import sparse_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
 
 np.random.seed(1)
@@ -31,6 +32,7 @@ env = RailEnv(width=20,
                                                    realistic_mode=True,
                                                    enhance_intersection=True
                                                    ),
+              schedule_generator=sparse_schedule_generator(),
               number_of_agents=5,
               stochastic_data=stochastic_data,  # Malfunction data generator
               obs_builder_object=TreeObservation)
@@ -75,7 +77,6 @@ class RandomAgent:
 # Set action space to 4 to remove stop action
 agent = RandomAgent(218, 4)
 
-
 # Empty dictionary for all agent action
 action_dict = dict()
 
diff --git a/examples/simple_example_1.py b/examples/simple_example_1.py
index 7956c34fd4a5b94859a4b64441450afe2114133c..fbadbd657c36fa1dadf0bca65cff3e9cccd269ea 100644
--- a/examples/simple_example_1.py
+++ b/examples/simple_example_1.py
@@ -1,5 +1,5 @@
-from flatland.envs.generators import rail_from_manual_specifications_generator
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import rail_from_manual_specifications_generator
 from flatland.utils.rendertools import RenderTool
 
 # Example generate a rail given a manual specification,
diff --git a/examples/simple_example_2.py b/examples/simple_example_2.py
index 994c7deda1569b77d4adac8a17fa9ebe14b27ef6..6db9ba5abbd0999ef3896e733516ed6b3e498bae 100644
--- a/examples/simple_example_2.py
+++ b/examples/simple_example_2.py
@@ -2,8 +2,8 @@ import random
 
 import numpy as np
 
-from flatland.envs.generators import random_rail_generator
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import random_rail_generator
 from flatland.utils.rendertools import RenderTool
 
 random.seed(100)
diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py
index 5aa03d8f95a7079b708baea1e2ddce27e9a46554..6df6d4af3076b3d9659aadbd55296b667dc7d6db 100644
--- a/examples/simple_example_3.py
+++ b/examples/simple_example_3.py
@@ -2,9 +2,10 @@ import random
 
 import numpy as np
 
-from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
 from flatland.utils.rendertools import RenderTool
 
 random.seed(1)
@@ -13,6 +14,7 @@ np.random.seed(1)
 env = RailEnv(width=7,
               height=7,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+              schedule_generator=complex_schedule_generator(),
               number_of_agents=2,
               obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
diff --git a/examples/training_example.py b/examples/training_example.py
index d125be1587a56025ba1cd3f78b28ba3976f01fbf..df93479f5a5ee05abfcb1a98b07ef052bffc2bd4 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -1,9 +1,10 @@
 import numpy as np
 
-from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
 from flatland.utils.rendertools import RenderTool
 
 np.random.seed(1)
@@ -16,11 +17,13 @@ LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
 env = RailEnv(width=20,
               height=20,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+              schedule_generator=complex_schedule_generator(),
               obs_builder_object=TreeObservation,
               number_of_agents=3)
 
 env_renderer = RenderTool(env, gl="PILSVG", )
 
+
 # Import your own Agent or use RLlib to train agents on Flatland
 # As an example we use a random agent here
 
diff --git a/flatland/cli.py b/flatland/cli.py
index 32e8d9dc786b0412795694fc985c90aa55fc2e91..47c450dba803fac17bc13979663ef04e4c0db899 100644
--- a/flatland/cli.py
+++ b/flatland/cli.py
@@ -2,29 +2,33 @@
 
 """Console script for flatland."""
 import sys
+import time
+
 import click
 import numpy as np
-import time
-from flatland.envs.generators import complex_rail_generator
+import redis
+
 from flatland.envs.rail_env import RailEnv
-from flatland.utils.rendertools import RenderTool
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
 from flatland.evaluators.service import FlatlandRemoteEvaluationService
-import redis
+from flatland.utils.rendertools import RenderTool
 
 
 @click.command()
 def demo(args=None):
     """Demo script to check installation"""
     env = RailEnv(
-            width=15,
-            height=15,
-            rail_generator=complex_rail_generator(
-                                    nr_start_goal=10,
-                                    nr_extra=1,
-                                    min_dist=8,
-                                    max_dist=99999),
-            number_of_agents=5)
-    
+        width=15,
+        height=15,
+        rail_generator=complex_rail_generator(
+            nr_start_goal=10,
+            nr_extra=1,
+            min_dist=8,
+            max_dist=99999),
+        schedule_generator=complex_schedule_generator(),
+        number_of_agents=5)
+
     env._max_episode_steps = int(15 * (env.width + env.height))
     env_renderer = RenderTool(env)
 
@@ -52,12 +56,12 @@ def demo(args=None):
 
 
 @click.command()
-@click.option('--tests', 
+@click.option('--tests',
               type=click.Path(exists=True),
               help="Path to folder containing Flatland tests",
               required=True
               )
-@click.option('--service_id', 
+@click.option('--service_id',
               default="FLATLAND_RL_SERVICE_ID",
               help="Evaluation Service ID. This has to match the service id on the client.",
               required=False
@@ -70,14 +74,14 @@ def evaluator(tests, service_id):
         raise Exception(
             "\nRedis server does not seem to be running on your localhost.\n"
             "Please ensure that you have a redis server running on your localhost"
-            )
-    
+        )
+
     grader = FlatlandRemoteEvaluationService(
-                test_env_folder=tests,
-                flatland_rl_service_id=service_id,
-                visualize=False,
-                verbose=False
-                )
+        test_env_folder=tests,
+        flatland_rl_service_id=service_id,
+        visualize=False,
+        verbose=False
+    )
     grader.run()
 
 
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index 0055b243668a4f3cd562958a59f52ba830af1c86..996bd73ad9de598eb162a937c135681675119ad3 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -5,10 +5,8 @@ Generator functions are functions that take width, height and num_resets as argu
 a GridTransitionMap object.
 """
 
-import numpy as np
-
 from flatland.core.grid.grid4_astar import a_star
-from flatland.core.grid.grid4_utils import get_direction, mirror, get_new_position
+from flatland.core.grid.grid4_utils import get_direction, mirror
 
 
 def connect_rail(rail_trans, rail_array, start, end):
@@ -195,81 +193,3 @@ def connect_to_nodes(rail_trans, rail_array, start, end):
 
         current_dir = new_dir
     return path
-
-
-def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
-    """
-    Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
-
-    TODO: add extensive documentation, as users may need this function to simplify their custom level generators.
-    """
-
-    def _path_exists(rail, start, direction, end):
-        # BFS - Check if a path exists between the 2 nodes
-
-        visited = set()
-        stack = [(start, direction)]
-        while stack:
-            node = stack.pop()
-            if node[0][0] == end[0] and node[0][1] == end[1]:
-                return 1
-            if node not in visited:
-                visited.add(node)
-                moves = rail.get_transitions(node[0][0], node[0][1], node[1])
-                for move_index in range(4):
-                    if moves[move_index]:
-                        stack.append((get_new_position(node[0], move_index),
-                                      move_index))
-
-                # If cell is a dead-end, append previous node with reversed
-                # orientation!
-                nbits = 0
-                tmp = rail.get_full_transitions(node[0][0], node[0][1])
-                while tmp > 0:
-                    nbits += (tmp & 1)
-                    tmp = tmp >> 1
-                if nbits == 1:
-                    stack.append((node[0], (node[1] + 2) % 4))
-
-        return 0
-
-    valid_positions = []
-    for r in range(rail.height):
-        for c in range(rail.width):
-            if rail.get_full_transitions(r, c) > 0:
-                valid_positions.append((r, c))
-
-    re_generate = True
-    while re_generate:
-        agents_position = [
-            valid_positions[i] for i in
-            np.random.choice(len(valid_positions), num_agents)]
-        agents_target = [
-            valid_positions[i] for i in
-            np.random.choice(len(valid_positions), num_agents)]
-
-        # agents_direction must be a direction for which a solution is
-        # guaranteed.
-        agents_direction = [0] * num_agents
-        re_generate = False
-        for i in range(num_agents):
-            valid_movements = []
-            for direction in range(4):
-                position = agents_position[i]
-                moves = rail.get_transitions(position[0], position[1], direction)
-                for move_index in range(4):
-                    if moves[move_index]:
-                        valid_movements.append((direction, move_index))
-
-            valid_starting_directions = []
-            for m in valid_movements:
-                new_position = get_new_position(agents_position[i], m[1])
-                if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]):
-                    valid_starting_directions.append(m[0])
-
-            if len(valid_starting_directions) == 0:
-                re_generate = True
-            else:
-                agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]
-
-    return agents_position, agents_direction, agents_target
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 6e6665af88c9e8b31e1a689815edb7aaada342f9..a61ef02207174d04489b5311dc042b7c06db1412 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -13,8 +13,9 @@ from flatland.core.env import Environment
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
-from flatland.envs.generators import random_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.rail_generators import random_rail_generator, RailGenerator
+from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
 
 m.patch()
 
@@ -92,7 +93,8 @@ class RailEnv(Environment):
     def __init__(self,
                  width,
                  height,
-                 rail_generator=random_rail_generator(),
+                 rail_generator: RailGenerator = random_rail_generator(),
+                 schedule_generator: ScheduleGenerator = random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
                  max_episode_steps=None,
@@ -108,13 +110,12 @@ class RailEnv(Environment):
             height and agents handles of a  rail environment, along with the number of times
             the env has been reset, and returns a GridTransitionMap object and a list of
             starting positions, targets, and initial orientations for agent handle.
-            Implemented functions are:
-                random_rail_generator : generate a random rail of given size
-                rail_from_grid_transition_map(rail_map) : generate a rail from
-                                        a GridTransitionMap object
-                rail_from_manual_sp ecifications_generator(rail_spec) : generate a rail from
-                                        a rail specifications array
-                TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
+            The rail_generator can pass a distance map in the hints or information for specific schedule_generators.
+            Implementations can be found in flatland/envs/rail_generators.py
+        schedule_generator : function
+            The schedule_generator function is a function that takes the grid, the number of agents and optional hints
+            and returns a list of starting positions, targets, initial orientations and speed for all agent handles.
+            Implementations can be found in flatland/envs/schedule_generators.py
         width : int
             The width of the rail map. Potentially in the future,
             a range of widths to sample from.
@@ -132,6 +133,8 @@ class RailEnv(Environment):
         file_name: you can load a pickle file.
         """
 
+        self.rail_generator: RailGenerator = rail_generator
+        self.schedule_generator: ScheduleGenerator = schedule_generator
         self.rail_generator = rail_generator
         self.rail: GridTransitionMap = None
         self.width = width
@@ -214,14 +217,13 @@ class RailEnv(Environment):
             if replace_agents then regenerate the agents static.
             Relies on the rail_generator returning agent_static lists (pos, dir, target)
         """
-        tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
+        rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
-        # Check if generator provided a distance map TODO: Make this check safer!
-        if len(tRailAgents) > 5:
-            self.obs_builder.distance_map = tRailAgents[-1]
+        if optionals and 'distance_maps' in optionals:
+            self.obs_builder.distance_map = optionals['distance_maps']
 
         if regen_rail or self.rail is None:
-            self.rail = tRailAgents[0]
+            self.rail = rail
             self.height, self.width = self.rail.grid.shape
             for r in range(self.height):
                 for c in range(self.width):
@@ -231,7 +233,11 @@ class RailEnv(Environment):
                         warnings.warn("Invalid grid at {} -> {}".format(rcPos, check))
 
         if replace_agents:
-            self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])
+            agents_hints = None
+            if optionals and 'agents_hints' in optionals:
+                agents_hints = optionals['agents_hints']
+            self.agents_static = EnvAgentStatic.from_lists(
+                *self.schedule_generator(self.rail, self.get_num_agents(), hints=agents_hints))
 
         self.restart_agents()
 
diff --git a/flatland/envs/generators.py b/flatland/envs/rail_generators.py
similarity index 87%
rename from flatland/envs/generators.py
rename to flatland/envs/rail_generators.py
index 525db36e8c5a09a451c0c59c1d03f1352f66e827..ed507dca9de9b3e90d412e77a7204037a5a20975 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/rail_generators.py
@@ -1,4 +1,6 @@
+"""Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
 import warnings
+from typing import Callable, Tuple, Any, Optional
 
 import msgpack
 import numpy as np
@@ -7,29 +9,34 @@ from flatland.core.grid.grid4_utils import get_direction, mirror
 from flatland.core.grid.grid_utils import distance_on_rail
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.agent_utils import EnvAgentStatic
 from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes
-from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail
 
+RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]]
+RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
 
-def empty_rail_generator():
+
+def empty_rail_generator() -> RailGenerator:
     """
     Returns a generator which returns an empty rail mail with no agents.
     Primarily used by the editor
     """
 
-    def generator(width, height, num_agents=0, num_resets=0):
+    def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct:
         rail_trans = RailEnvTransitions()
         grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
         rail_array = grid_map.grid
         rail_array.fill(0)
 
-        return grid_map, [], [], [], []
+        return grid_map, None
 
     return generator
 
 
-def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=99999, seed=0):
+def complex_rail_generator(nr_start_goal=1,
+                           nr_extra=100,
+                           min_dist=20,
+                           max_dist=99999,
+                           seed=0) -> RailGenerator:
     """
     Parameters
     -------
@@ -49,8 +56,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
         if num_agents > nr_start_goal:
             num_agents = nr_start_goal
             print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
-        rail_trans = RailEnvTransitions()
-        grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
+        grid_map = GridTransitionMap(width=width, height=height, transitions=RailEnvTransitions())
         rail_array = grid_map.grid
         rail_array.fill(0)
 
@@ -74,6 +80,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
         # - return transition map + list of [start_pos, start_dir, goal_pos] points
         #
 
+        rail_trans = grid_map.transitions
         start_goal = []
         start_dir = []
         nr_created = 0
@@ -143,11 +150,10 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
             if len(new_path) >= 2:
                 nr_created += 1
 
-        agents_position = [sg[0] for sg in start_goal[:num_agents]]
-        agents_target = [sg[1] for sg in start_goal[:num_agents]]
-        agents_direction = start_dir[:num_agents]
-
-        return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+        return grid_map, {'agents_hints': {
+            'start_goal': start_goal,
+            'start_dir': start_dir
+        }}
 
     return generator
 
@@ -191,22 +197,18 @@ def rail_from_manual_specifications_generator(rail_spec):
                 effective_transition_cell = rail_env_transitions.rotate_transition(basic_type_of_cell_, rotation_cell_)
                 rail.set_transitions((r, c), effective_transition_cell)
 
-        agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
-            rail,
-            num_agents)
-
-        return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+        return [rail, None]
 
     return generator
 
 
-def rail_from_file(filename):
+def rail_from_file(filename) -> RailGenerator:
     """
     Utility to load pickle file
 
     Parameters
     -------
-    input_file : Pickle file generated by env.save() or editor
+    filename : Pickle file generated by env.save() or editor
 
     Returns
     -------
@@ -224,26 +226,16 @@ def rail_from_file(filename):
         grid = np.array(data[b"grid"])
         rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions)
         rail.grid = grid
-        # agents are always reset as not moving
-        agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
-        # setup with loaded data
-        agents_position = [a.position for a in agents_static]
-        agents_direction = [a.direction for a in agents_static]
-        agents_target = [a.target for a in agents_static]
         if b"distance_maps" in data.keys():
             distance_maps = data[b"distance_maps"]
             if len(distance_maps) > 0:
-                return rail, agents_position, agents_direction, agents_target, [1.0] * len(
-                    agents_position), distance_maps
-            else:
-                return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
-        else:
-            return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+                return rail, {'distance_maps': distance_maps}
+        return [rail, None]
 
     return generator
 
 
-def rail_from_grid_transition_map(rail_map):
+def rail_from_grid_transition_map(rail_map) -> RailGenerator:
     """
     Utility to convert a rail given by a GridTransitionMap map with the correct
     16-bit transitions specifications.
@@ -259,17 +251,13 @@ def rail_from_grid_transition_map(rail_map):
         Generator function that always returns the given `rail_map' object.
     """
 
-    def generator(width, height, num_agents, num_resets=0):
-        agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
-            rail_map,
-            num_agents)
-
-        return rail_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+    def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
+        return rail_map, None
 
     return generator
 
 
-def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
+def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGenerator:
     """
     Dummy random level generator:
     - fill in cells at random in [width-2, height-2]
@@ -301,7 +289,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
         The matrix with the correct 16-bit bitmaps for each cell.
     """
 
-    def generator(width, height, num_agents, num_resets=0):
+    def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
         t_utils = RailEnvTransitions()
 
         transition_probability = cell_type_relative_proportion
@@ -533,11 +521,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
         return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
         return_rail.grid = tmp_rail
 
-        agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
-            return_rail,
-            num_agents)
-
-        return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+        return return_rail, None
 
     return generator
 
@@ -802,48 +786,9 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
             else:
                 num_agents -= 1
 
-        # Place agents and targets within available train stations
-        agents_position = []
-        agents_target = []
-        agents_direction = []
-
-        for agent_idx in range(num_agents):
-            # Set target for agent
-            current_target_node = agent_start_targets_nodes[agent_idx][1]
-            target_station_idx = np.random.randint(len(train_stations[current_target_node]))
-            target = train_stations[current_target_node][target_station_idx]
-            tries = 0
-            while (target[0], target[1]) in agents_target:
-                target_station_idx = np.random.randint(len(train_stations[current_target_node]))
-                target = train_stations[current_target_node][target_station_idx]
-                tries += 1
-                if tries > 100:
-                    warnings.warn("Could not set target position, removing an agent")
-                    break
-            agents_target.append((target[0], target[1]))
-
-            # Set start for agent
-            current_start_node = agent_start_targets_nodes[agent_idx][0]
-            start_station_idx = np.random.randint(len(train_stations[current_start_node]))
-            start = train_stations[current_start_node][start_station_idx]
-            tries = 0
-            while (start[0], start[1]) in agents_position:
-                tries += 1
-                if tries > 100:
-                    warnings.warn("Could not set start position, please change initial parameters!!!!")
-                    break
-                start_station_idx = np.random.randint(len(train_stations[current_start_node]))
-                start = train_stations[current_start_node][start_station_idx]
-
-            agents_position.append((start[0], start[1]))
-
-            # Orient the agent correctly
-            for orientation in range(4):
-                transitions = grid_map.get_transitions(start[0], start[1], orientation)
-                if any(transitions) > 0:
-                    agents_direction.append(orientation)
-                    continue
-
-        return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+        return grid_map, {'agents_hints': {
+            'agent_start_targets_nodes': agent_start_targets_nodes,
+            'train_stations': train_stations
+        }}
 
     return generator
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ebc6c71c17db308789a4baf0ec99729ec9991e8
--- /dev/null
+++ b/flatland/envs/schedule_generators.py
@@ -0,0 +1,238 @@
+"""Schedule generators (railway undertaking, "EVU")."""
+import warnings
+from typing import Tuple, List, Callable, Mapping, Optional, Any
+
+import msgpack
+import numpy as np
+
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.agent_utils import EnvAgentStatic
+
+AgentPosition = Tuple[int, int]
+ScheduleGeneratorProduct = Tuple[List[AgentPosition], List[AgentPosition], List[AgentPosition], List[float]]
+ScheduleGenerator = Callable[[GridTransitionMap, int, Optional[Any]], ScheduleGeneratorProduct]
+
+
+def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float] = None) -> List[float]:
+    """
+    Parameters
+    -------
+    nb_agents : int
+        The number of agents to generate a speed for
+    speed_ratio_map : Mapping[float,float]
+        A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1.
+
+    Returns
+    -------
+    List[float]
+        A list of size nb_agents of speeds with the corresponding probabilistic ratios.
+    """
+    if speed_ratio_map is None:
+        return [1.0] * nb_agents
+
+    nb_classes = len(speed_ratio_map.keys())
+    speed_ratio_map_as_list: List[Tuple[float, float]] = list(speed_ratio_map.items())
+    speed_ratios = list(map(lambda t: t[1], speed_ratio_map_as_list))
+    speeds = list(map(lambda t: t[0], speed_ratio_map_as_list))
+    return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios)))
+
+
+def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
+        start_goal = hints['start_goal']
+        start_dir = hints['start_dir']
+        agents_position = [sg[0] for sg in start_goal[:num_agents]]
+        agents_target = [sg[1] for sg in start_goal[:num_agents]]
+        agents_direction = start_dir[:num_agents]
+
+        if speed_ratio_map:
+            speeds = speed_initialization_helper(num_agents, speed_ratio_map)
+        else:
+            speeds = [1.0] * len(agents_position)
+
+        return agents_position, agents_direction, agents_target, speeds
+
+    return generator
+
+
+def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
+        train_stations = hints['train_stations']
+        agent_start_targets_nodes = hints['agent_start_targets_nodes']
+        # Place agents and targets within available train stations
+        agents_position = []
+        agents_target = []
+        agents_direction = []
+        for agent_idx in range(num_agents):
+            # Set target for agent
+            current_target_node = agent_start_targets_nodes[agent_idx][1]
+            target_station_idx = np.random.randint(len(train_stations[current_target_node]))
+            target = train_stations[current_target_node][target_station_idx]
+            tries = 0
+            while (target[0], target[1]) in agents_target:
+                target_station_idx = np.random.randint(len(train_stations[current_target_node]))
+                target = train_stations[current_target_node][target_station_idx]
+                tries += 1
+                if tries > 100:
+                    warnings.warn("Could not set target position, removing an agent")
+                    break
+            agents_target.append((target[0], target[1]))
+
+            # Set start for agent
+            current_start_node = agent_start_targets_nodes[agent_idx][0]
+            start_station_idx = np.random.randint(len(train_stations[current_start_node]))
+            start = train_stations[current_start_node][start_station_idx]
+            tries = 0
+            while (start[0], start[1]) in agents_position:
+                tries += 1
+                if tries > 100:
+                    warnings.warn("Could not set start position, please change initial parameters!!!!")
+                    break
+                start_station_idx = np.random.randint(len(train_stations[current_start_node]))
+                start = train_stations[current_start_node][start_station_idx]
+
+            agents_position.append((start[0], start[1]))
+
+            # Orient the agent correctly
+            for orientation in range(4):
+                transitions = rail.get_transitions(start[0], start[1], orientation)
+                if any(transitions) > 0:
+                    agents_direction.append(orientation)
+                    continue
+
+        if speed_ratio_map:
+            speeds = speed_initialization_helper(num_agents, speed_ratio_map)
+        else:
+            speeds = [1.0] * len(agents_position)
+
+        return agents_position, agents_direction, agents_target, speeds
+
+    return generator
+
+
+def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
+    """
+    Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
+
+    Parameters
+    -------
+        rail : GridTransitionMap
+            The railway to place agents on.
+        num_agents : int
+            The number of agents to generate a speed for
+        speed_ratio_map : Mapping[float,float]
+            A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1.
+    Returns
+    -------
+        Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]]
+        initial positions, directions, targets speeds
+    """
+
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct:
+        def _path_exists(rail, start, direction, end):
+            # BFS - Check if a path exists between the 2 nodes
+
+            visited = set()
+            stack = [(start, direction)]
+            while stack:
+                node = stack.pop()
+                if node[0][0] == end[0] and node[0][1] == end[1]:
+                    return 1
+                if node not in visited:
+                    visited.add(node)
+                    moves = rail.get_transitions(node[0][0], node[0][1], node[1])
+                    for move_index in range(4):
+                        if moves[move_index]:
+                            stack.append((get_new_position(node[0], move_index),
+                                          move_index))
+
+                    # If cell is a dead-end, append previous node with reversed
+                    # orientation!
+                    nbits = 0
+                    tmp = rail.get_full_transitions(node[0][0], node[0][1])
+                    while tmp > 0:
+                        nbits += (tmp & 1)
+                        tmp = tmp >> 1
+                    if nbits == 1:
+                        stack.append((node[0], (node[1] + 2) % 4))
+
+            return 0
+
+        valid_positions = []
+        for r in range(rail.height):
+            for c in range(rail.width):
+                if rail.get_full_transitions(r, c) > 0:
+                    valid_positions.append((r, c))
+        if len(valid_positions) == 0:
+            return [], [], [], []
+        re_generate = True
+        while re_generate:
+            agents_position = [
+                valid_positions[i] for i in
+                np.random.choice(len(valid_positions), num_agents)]
+            agents_target = [
+                valid_positions[i] for i in
+                np.random.choice(len(valid_positions), num_agents)]
+
+            # agents_direction must be a direction for which a solution is
+            # guaranteed.
+            agents_direction = [0] * num_agents
+            re_generate = False
+            for i in range(num_agents):
+                valid_movements = []
+                for direction in range(4):
+                    position = agents_position[i]
+                    moves = rail.get_transitions(position[0], position[1], direction)
+                    for move_index in range(4):
+                        if moves[move_index]:
+                            valid_movements.append((direction, move_index))
+
+                valid_starting_directions = []
+                for m in valid_movements:
+                    new_position = get_new_position(agents_position[i], m[1])
+                    if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0],
+                                                                              agents_target[i]):
+                        valid_starting_directions.append(m[0])
+
+                if len(valid_starting_directions) == 0:
+                    re_generate = True
+                else:
+                    agents_direction[i] = valid_starting_directions[
+                        np.random.choice(len(valid_starting_directions), 1)[0]]
+
+        agents_speed = speed_initialization_helper(num_agents, speed_ratio_map)
+        return agents_position, agents_direction, agents_target, agents_speed
+
+    return generator
+
+
+def agents_from_file(filename) -> ScheduleGenerator:
+    """
+    Utility to load pickle file
+
+    Parameters
+    -------
+    input_file : Pickle file generated by env.save() or editor
+
+    Returns
+    -------
+    Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]]
+        initial positions, directions, targets speeds
+    """
+
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct:
+        with open(filename, "rb") as file_in:
+            load_data = file_in.read()
+        data = msgpack.unpackb(load_data, use_list=False)
+
+        # agents are always reset as not moving
+        agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
+        # setup with loaded data
+        agents_position = [a.position for a in agents_static]
+        agents_direction = [a.direction for a in agents_static]
+        agents_target = [a.target for a in agents_static]
+
+        return agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+
+    return generator
diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py
index a4968c0c8e827c060e0e3f7de0cf28cc0658089b..f2dac1c8705a651e2a7be026b4bd82a961efbbdd 100644
--- a/flatland/evaluators/client.py
+++ b/flatland/evaluators/client.py
@@ -1,18 +1,21 @@
-import redis
+import hashlib
 import json
+import logging
 import os
-import numpy as np
+import random
+import time
+
 import msgpack
 import msgpack_numpy as m
-import hashlib
-import random
-from flatland.evaluators import messages
-from flatland.envs.rail_env import RailEnv
-from flatland.envs.generators import rail_from_file
+import numpy as np
+import redis
+
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
-import time
-import logging
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import rail_from_file
+from flatland.evaluators import messages
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 m.patch()
@@ -22,8 +25,8 @@ def are_dicts_equal(d1, d2):
     """ return True if all keys and values are the same """
     return all(k in d2 and d1[k] == d2[k]
                for k in d1) \
-        and all(k in d1 and d1[k] == d2[k]
-               for k in d2)
+           and all(k in d1 and d1[k] == d2[k]
+                   for k in d2)
 
 
 class FlatlandRemoteClient(object):
@@ -41,39 +44,40 @@ class FlatlandRemoteClient(object):
         where `service_id` is either provided as an `env` variable or is
         instantiated to "flatland_rl_redis_service_id"
     """
-    def __init__(self,  
-                remote_host='127.0.0.1',
-                remote_port=6379,
-                remote_db=0,
-                remote_password=None,
-                test_envs_root=None,
-                verbose=False):
+
+    def __init__(self,
+                 remote_host='127.0.0.1',
+                 remote_port=6379,
+                 remote_db=0,
+                 remote_password=None,
+                 test_envs_root=None,
+                 verbose=False):
 
         self.remote_host = remote_host
         self.remote_port = remote_port
         self.remote_db = remote_db
         self.remote_password = remote_password
         self.redis_pool = redis.ConnectionPool(
-                                host=remote_host,
-                                port=remote_port,
-                                db=remote_db,
-                                password=remote_password)
+            host=remote_host,
+            port=remote_port,
+            db=remote_db,
+            password=remote_password)
         self.namespace = "flatland-rl"
         self.service_id = os.getenv(
-                            'FLATLAND_RL_SERVICE_ID',
-                            'FLATLAND_RL_SERVICE_ID'
-                            )
+            'FLATLAND_RL_SERVICE_ID',
+            'FLATLAND_RL_SERVICE_ID'
+        )
         self.command_channel = "{}::{}::commands".format(
-                                    self.namespace,
-                                    self.service_id
-                                )
+            self.namespace,
+            self.service_id
+        )
         if test_envs_root:
             self.test_envs_root = test_envs_root
         else:
             self.test_envs_root = os.getenv(
-                                'AICROWD_TESTS_FOLDER',
-                                '/tmp/flatland_envs'
-                                )
+                'AICROWD_TESTS_FOLDER',
+                '/tmp/flatland_envs'
+            )
 
         self.verbose = verbose
 
@@ -85,12 +89,12 @@ class FlatlandRemoteClient(object):
 
     def _generate_response_channel(self):
         random_hash = hashlib.md5(
-                        "{}".format(
-                                random.randint(0, 10**10)
-                            ).encode('utf-8')).hexdigest()
+            "{}".format(
+                random.randint(0, 10 ** 10)
+            ).encode('utf-8')).hexdigest()
         response_channel = "{}::{}::response::{}".format(self.namespace,
-                                                        self.service_id,
-                                                        random_hash)
+                                                         self.service_id,
+                                                         random_hash)
         return response_channel
 
     def _blocking_request(self, _request):
@@ -124,9 +128,9 @@ class FlatlandRemoteClient(object):
         if self.verbose:
             print("Response : ", _response)
         _response = msgpack.unpackb(
-                        _response, 
-                        object_hook=m.decode, 
-                        encoding="utf8")
+            _response,
+            object_hook=m.decode,
+            encoding="utf8")
         if _response['type'] == messages.FLATLAND_RL.ERROR:
             raise Exception(str(_response["payload"]))
         else:
@@ -181,7 +185,7 @@ class FlatlandRemoteClient(object):
                 "Did you remember to set the AICROWD_TESTS_FOLDER environment variable "
                 "to point to the location of the Tests folder ? \n"
                 "We are currently looking at `{}` for the tests".format(self.test_envs_root)
-                )
+            )
         print("Current env path : ", test_env_file_path)
         self.env = RailEnv(
             width=1,
@@ -207,7 +211,7 @@ class FlatlandRemoteClient(object):
         _request['payload']['action'] = action
         _response = self._blocking_request(_request)
         _payload = _response['payload']
-        
+
         # remote_observation = _payload['observation']
         remote_reward = _payload['reward']
         remote_done = _payload['done']
@@ -216,14 +220,14 @@ class FlatlandRemoteClient(object):
         # Replicate the action in the local env
         local_observation, local_reward, local_done, local_info = \
             self.env.step(action)
-        
+
         print(local_reward)
         if not are_dicts_equal(remote_reward, local_reward):
             raise Exception("local and remote `reward` are diverging")
             print(remote_reward, local_reward)
         if not are_dicts_equal(remote_done, local_done):
             raise Exception("local and remote `done` are diverging")
-        
+
         # Return local_observation instead of remote_observation
         # as the remote_observation is build using a dummy observation
         # builder
@@ -250,21 +254,23 @@ class FlatlandRemoteClient(object):
 if __name__ == "__main__":
     remote_client = FlatlandRemoteClient()
 
+
     def my_controller(obs, _env):
         _action = {}
         for _idx, _ in enumerate(_env.agents):
             _action[_idx] = np.random.randint(0, 5)
         return _action
-    
+
+
     my_observation_builder = TreeObsForRailEnv(max_depth=3,
-                                predictor=ShortestPathPredictorForRailEnv())
+                                               predictor=ShortestPathPredictorForRailEnv())
 
     episode = 0
     obs = True
-    while obs:        
+    while obs:
         obs = remote_client.env_create(
-                    obs_builder_object=my_observation_builder
-                    )
+            obs_builder_object=my_observation_builder
+        )
         if not obs:
             """
             The remote env returns False as the first obs
@@ -285,7 +291,5 @@ if __name__ == "__main__":
                 print("Reward : ", sum(list(all_rewards.values())))
                 break
 
-    print("Evaluation Complete...")       
+    print("Evaluation Complete...")
     print(remote_client.submit())
-
-
diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py
index 3ad0a97598c8beb66fc164eb45b670c87f3c96f9..8967b52d9d6ee70a7eb8af257ef6b4e25b531314 100644
--- a/flatland/evaluators/service.py
+++ b/flatland/evaluators/service.py
@@ -1,24 +1,26 @@
 #!/usr/bin/env python
 from __future__ import print_function
-import redis
-from flatland.envs.generators import rail_from_file
-from flatland.envs.rail_env import RailEnv
-from flatland.core.env_observation_builder import DummyObservationBuilder
-from flatland.evaluators import messages
-from flatland.evaluators import aicrowd_helpers
-from flatland.utils.rendertools import RenderTool
-import numpy as np
-import msgpack
-import msgpack_numpy as m
-import os
+
 import glob
+import os
+import random
 import shutil
 import time
 import traceback
+
 import crowdai_api
+import msgpack
+import msgpack_numpy as m
+import numpy as np
+import redis
 import timeout_decorator
-import random
 
+from flatland.core.env_observation_builder import DummyObservationBuilder
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import rail_from_file
+from flatland.evaluators import aicrowd_helpers
+from flatland.evaluators import messages
+from flatland.utils.rendertools import RenderTool
 
 use_signals_in_timeout = True
 if os.name == 'nt':
@@ -35,7 +37,7 @@ m.patch()
 ########################################################
 # CONSTANTS
 ########################################################
-PER_STEP_TIMEOUT = 10*60  # 5 minutes
+PER_STEP_TIMEOUT = 10 * 60  # 5 minutes
 
 
 class FlatlandRemoteEvaluationService:
@@ -59,17 +61,18 @@ class FlatlandRemoteEvaluationService:
     unpacked with `msgpack` (a patched version of msgpack which also supports
     numpy arrays).
     """
+
     def __init__(self,
-                test_env_folder="/tmp",
-                flatland_rl_service_id='FLATLAND_RL_SERVICE_ID',
-                remote_host='127.0.0.1',
-                remote_port=6379,
-                remote_db=0,
-                remote_password=None,
-                visualize=False,
-                video_generation_envs=[],
-                report=None,
-                verbose=False):
+                 test_env_folder="/tmp",
+                 flatland_rl_service_id='FLATLAND_RL_SERVICE_ID',
+                 remote_host='127.0.0.1',
+                 remote_port=6379,
+                 remote_db=0,
+                 remote_password=None,
+                 visualize=False,
+                 video_generation_envs=[],
+                 report=None,
+                 verbose=False):
 
         # Test Env folder Paths
         self.test_env_folder = test_env_folder
@@ -83,15 +86,15 @@ class FlatlandRemoteEvaluationService:
         # Logging and Reporting related vars
         self.verbose = verbose
         self.report = report
-        
+
         # Communication Protocol Related vars
         self.namespace = "flatland-rl"
         self.service_id = flatland_rl_service_id
         self.command_channel = "{}::{}::commands".format(
-                                    self.namespace, 
-                                    self.service_id
-                                )
-        
+            self.namespace,
+            self.service_id
+        )
+
         # Message Broker related vars
         self.remote_host = remote_host
         self.remote_port = remote_port
@@ -114,7 +117,7 @@ class FlatlandRemoteEvaluationService:
                 "normalized_reward": 0.0
             }
         }
-        
+
         # RailEnv specific variables
         self.env = False
         self.env_renderer = False
@@ -156,7 +159,7 @@ class FlatlandRemoteEvaluationService:
             ├── .......
             ├── .......
             └── Level_99.pkl 
-        """            
+        """
         env_paths = sorted(glob.glob(
             os.path.join(
                 self.test_env_folder,
@@ -179,16 +182,16 @@ class FlatlandRemoteEvaluationService:
         """
         if self.verbose or self.report:
             print("Attempting to connect to redis server at {}:{}/{}".format(
-                    self.remote_host, 
-                    self.remote_port, 
-                    self.remote_db))
+                self.remote_host,
+                self.remote_port,
+                self.remote_db))
 
         self.redis_pool = redis.ConnectionPool(
-                            host=self.remote_host, 
-                            port=self.remote_port, 
-                            db=self.remote_db, 
-                            password=self.remote_password
-                        )
+            host=self.remote_host,
+            port=self.remote_port,
+            db=self.remote_db,
+            password=self.remote_password
+        )
 
     def get_redis_connection(self):
         """
@@ -200,13 +203,13 @@ class FlatlandRemoteEvaluationService:
             redis_conn.ping()
         except Exception as e:
             raise Exception(
-                    "Unable to connect to redis server at {}:{} ."
-                    "Are you sure there is a redis-server running at the "
-                    "specified location ?".format(
-                        self.remote_host,
-                        self.remote_port
-                        )
-                    )
+                "Unable to connect to redis server at {}:{} ."
+                "Are you sure there is a redis-server running at the "
+                "specified location ?".format(
+                    self.remote_host,
+                    self.remote_port
+                )
+            )
         return redis_conn
 
     def _error_template(self, payload):
@@ -220,8 +223,8 @@ class FlatlandRemoteEvaluationService:
         return _response
 
     @timeout_decorator.timeout(
-                        PER_STEP_TIMEOUT,
-                        use_signals=use_signals_in_timeout)  # timeout for each command
+        PER_STEP_TIMEOUT,
+        use_signals=use_signals_in_timeout)  # timeout for each command
     def _get_next_command(self, _redis):
         """
         A low level wrapper for obtaining the next command from a 
@@ -231,7 +234,7 @@ class FlatlandRemoteEvaluationService:
         """
         command = _redis.brpop(self.command_channel)[1]
         return command
-    
+
     def get_next_command(self):
         """
         A helper function to obtain the next command, which transparently 
@@ -246,18 +249,18 @@ class FlatlandRemoteEvaluationService:
                 print("Command Service: ", command)
         except timeout_decorator.timeout_decorator.TimeoutError:
             raise Exception(
-                    "Timeout in step {} of simulation {}".format(
-                            self.current_step,
-                            self.simulation_count
-                            ))
+                "Timeout in step {} of simulation {}".format(
+                    self.current_step,
+                    self.simulation_count
+                ))
         command = msgpack.unpackb(
-                    command, 
-                    object_hook=m.decode, 
-                    encoding="utf8"
-                )
+            command,
+            object_hook=m.decode,
+            encoding="utf8"
+        )
         if self.verbose:
             print("Received Request : ", command)
-        
+
         return command
 
     def send_response(self, _command_response, command, suppress_logs=False):
@@ -266,15 +269,15 @@ class FlatlandRemoteEvaluationService:
 
         if self.verbose and not suppress_logs:
             print("Responding with : ", _command_response)
-        
+
         _redis.rpush(
-            command_response_channel, 
+            command_response_channel,
             msgpack.packb(
-                _command_response, 
-                default=m.encode, 
+                _command_response,
+                default=m.encode,
                 use_bin_type=True)
         )
-        
+
     def handle_ping(self, command):
         """
         Handles PING command from the client.
@@ -313,9 +316,9 @@ class FlatlandRemoteEvaluationService:
             )
             if self.visualize:
                 if self.env_renderer:
-                    del self.env_renderer     
+                    del self.env_renderer
                 self.env_renderer = RenderTool(self.env, gl="PILSVG", )
-            
+
             # Set max episode steps allowed
             self.env._max_episode_steps = \
                 int(1.5 * (self.env.width + self.env.height))
@@ -323,7 +326,7 @@ class FlatlandRemoteEvaluationService:
             if self.begin_simulation:
                 # If begin simulation has already been initialized 
                 # atleast once
-                self.simulation_times.append(time.time()-self.begin_simulation)
+                self.simulation_times.append(time.time() - self.begin_simulation)
             self.begin_simulation = time.time()
 
             self.simulation_rewards.append(0)
@@ -348,15 +351,15 @@ class FlatlandRemoteEvaluationService:
             _command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE
             _command_response['payload'] = {}
             _command_response['payload']['observation'] = False
-            _command_response['payload']['env_file_path'] = False            
+            _command_response['payload']['env_file_path'] = False
 
         self.send_response(_command_response, command)
         #####################################################################
         # Update evaluation state
         #####################################################################
         progress = np.clip(
-                    self.simulation_count * 1.0 / len(self.env_file_paths),
-                    0, 1)
+            self.simulation_count * 1.0 / len(self.env_file_paths),
+            0, 1)
         mean_reward = round(np.mean(self.simulation_rewards), 2)
         mean_normalized_reward = round(np.mean(self.simulation_rewards_normalized), 2)
         mean_percentage_complete = round(np.mean(self.simulation_percentage_complete), 3)
@@ -399,9 +402,9 @@ class FlatlandRemoteEvaluationService:
         """
         self.simulation_rewards_normalized[-1] += \
             cumulative_reward / (
-                        self.env._max_episode_steps + 
-                        self.env.get_num_agents()
-                    )
+                self.env._max_episode_steps +
+                self.env.get_num_agents()
+            )
 
         if done["__all__"]:
             # Compute percentage complete
@@ -412,14 +415,14 @@ class FlatlandRemoteEvaluationService:
                     complete += 1
             percentage_complete = complete * 1.0 / self.env.get_num_agents()
             self.simulation_percentage_complete[-1] = percentage_complete
-        
+
         # Record Frame
         if self.visualize:
             self.env_renderer.render_env(
-                                show=False, 
-                                show_observations=False, 
-                                show_predictions=False
-                                )
+                show=False,
+                show_observations=False,
+                show_predictions=False
+            )
             """
             Only save the frames for environments which are separately provided 
             in video_generation_indices param
@@ -427,10 +430,10 @@ class FlatlandRemoteEvaluationService:
             current_env_path = self.env_file_paths[self.simulation_count]
             if current_env_path in self.video_generation_envs:
                 self.env_renderer.gl.save_image(
-                        os.path.join(
-                            self.vizualization_folder_name,
-                            "flatland_frame_{:04d}.png".format(self.record_frame_step)
-                        ))
+                    os.path.join(
+                        self.vizualization_folder_name,
+                        "flatland_frame_{:04d}.png".format(self.record_frame_step)
+                    ))
                 self.record_frame_step += 1
 
         # Build and send response
@@ -453,7 +456,7 @@ class FlatlandRemoteEvaluationService:
         _payload = command['payload']
 
         # Register simulation time of the last episode
-        self.simulation_times.append(time.time()-self.begin_simulation)
+        self.simulation_times.append(time.time() - self.begin_simulation)
 
         if len(self.simulation_rewards) != len(self.env_file_paths):
             raise Exception(
@@ -461,7 +464,7 @@ class FlatlandRemoteEvaluationService:
                 to operate on all the test environments.
                 """
             )
-        
+
         mean_reward = round(np.mean(self.simulation_rewards), 2)
         mean_normalized_reward = round(np.mean(self.simulation_rewards_normalized), 2)
         mean_percentage_complete = round(np.mean(self.simulation_percentage_complete), 3)
@@ -473,7 +476,7 @@ class FlatlandRemoteEvaluationService:
             # install it by : 
             #
             # conda install -c conda-forge x264 ffmpeg
-            
+
             print("Generating Video from thumbnails...")
             video_output_path, video_thumb_output_path = \
                 aicrowd_helpers.generate_movie_from_frames(
@@ -518,14 +521,14 @@ class FlatlandRemoteEvaluationService:
         self.evaluation_state["score"]["score_secondary"] = mean_reward
         self.evaluation_state["meta"]["normalized_reward"] = mean_normalized_reward
         self.handle_aicrowd_success_event(self.evaluation_state)
-        print("#"*100)
+        print("#" * 100)
         print("EVALUATION COMPLETE !!")
-        print("#"*100)
+        print("#" * 100)
         print("# Mean Reward : {}".format(mean_reward))
         print("# Mean Normalized Reward : {}".format(mean_normalized_reward))
         print("# Mean Percentage Complete : {}".format(mean_percentage_complete))
-        print("#"*100)
-        print("#"*100)
+        print("#" * 100)
+        print("#" * 100)
 
     def report_error(self, error_message, command_response_channel):
         """
@@ -536,16 +539,16 @@ class FlatlandRemoteEvaluationService:
         _command_response['type'] = messages.FLATLAND_RL.ERROR
         _command_response['payload'] = error_message
         _redis.rpush(
-            command_response_channel, 
+            command_response_channel,
             msgpack.packb(
-                _command_response, 
-                default=m.encode, 
+                _command_response,
+                default=m.encode,
                 use_bin_type=True)
-            )
+        )
         self.evaluation_state["state"] = "ERROR"
         self.evaluation_state["error"] = error_message
         self.handle_aicrowd_error_event(self.evaluation_state)
-    
+
     def handle_aicrowd_info_event(self, payload):
         self.oracle_events.register_event(
             event_type=self.oracle_events.CROWDAI_EVENT_INFO,
@@ -577,17 +580,17 @@ class FlatlandRemoteEvaluationService:
                 print("Self.Reward : ", self.reward)
                 print("Current Simulation : ", self.simulation_count)
                 if self.env_file_paths and \
-                        self.simulation_count < len(self.env_file_paths):
+                    self.simulation_count < len(self.env_file_paths):
                     print("Current Env Path : ",
-                        self.env_file_paths[self.simulation_count])
+                          self.env_file_paths[self.simulation_count])
 
-            try:                
+            try:
                 if command['type'] == messages.FLATLAND_RL.PING:
                     """
                         INITIAL HANDSHAKE : Respond with PONG
                     """
                     self.handle_ping(command)
-                
+
                 elif command['type'] == messages.FLATLAND_RL.ENV_CREATE:
                     """
                         ENV_CREATE
@@ -612,8 +615,8 @@ class FlatlandRemoteEvaluationService:
                     self.handle_env_submit(command)
                 else:
                     _error = self._error_template(
-                                    "UNKNOWN_REQUEST:{}".format(
-                                        str(command)))
+                        "UNKNOWN_REQUEST:{}".format(
+                            str(command)))
                     if self.verbose:
                         print("Responding with : ", _error)
                     self.report_error(
@@ -631,10 +634,11 @@ class FlatlandRemoteEvaluationService:
 
 if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser(description='Submit the result to AIcrowd')
-    parser.add_argument('--service_id', 
-                        dest='service_id', 
-                        default='FLATLAND_RL_SERVICE_ID', 
+    parser.add_argument('--service_id',
+                        dest='service_id',
+                        default='FLATLAND_RL_SERVICE_ID',
                         required=False)
     parser.add_argument('--test_folder',
                         dest='test_folder',
@@ -642,16 +646,16 @@ if __name__ == "__main__":
                         help="Folder containing the files for the test envs",
                         required=False)
     args = parser.parse_args()
-    
+
     test_folder = args.test_folder
 
     grader = FlatlandRemoteEvaluationService(
-                test_env_folder=test_folder,
-                flatland_rl_service_id=args.service_id,
-                verbose=True,
-                visualize=True,
-                video_generation_envs=["Test_0/Level_1.pkl"]
-                )
+        test_env_folder=test_folder,
+        flatland_rl_service_id=args.service_id,
+        verbose=True,
+        visualize=True,
+        video_generation_envs=["Test_0/Level_1.pkl"]
+    )
     result = grader.run()
     if result['type'] == messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE:
         cumulative_results = result['payload']
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index 69be59ae2a957f6a2aaa948d9830472d7824516a..af1aad222919b00b716dd9da0f3be9534d54e411 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -11,9 +11,9 @@ from numpy import array
 import flatland.utils.rendertools as rt
 from flatland.core.grid.grid4_utils import mirror
 from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic
-from flatland.envs.generators import complex_rail_generator, empty_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv, random_rail_generator
+from flatland.envs.rail_generators import complex_rail_generator, empty_rail_generator
 
 
 class EditorMVC(object):
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index 6a0a9282614c0319338454f5b8ae97531b12e432..92a0f84f35fa942b03236c6add6e722475a2d842 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -172,7 +172,7 @@ class PILGL(GraphicsLayer):
     def text(self, xPx, yPx, strText, layer=RAIL_LAYER):
         xyPixLeftTop = (xPx, yPx)
         self.draws[layer].text(xyPixLeftTop, strText, font=self.font, fill=(0, 0, 0, 255))
-        
+
     def text_rowcol(self, rcTopLeft, strText, layer=AGENT_LAYER):
         print("Text:", "rc:", rcTopLeft, "text:", strText, "layer:", layer)
         xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]])
@@ -500,9 +500,9 @@ class PILSVG(PILGL):
                                           False)[0]
         self.draw_image_row_col(colored_rail, (row, col), layer=PILGL.PREDICTION_PATH_LAYER)
 
-    def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None, 
-            show_debug=True):
-        
+    def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None,
+                    show_debug=True):
+
         if binary_trans in self.pil_rail:
             pil_track = self.pil_rail[binary_trans]
             if target is not None:
@@ -510,7 +510,7 @@ class PILSVG(PILGL):
                 target_img = Image.alpha_composite(pil_track, target_img)
                 self.draw_image_row_col(target_img, (row, col), layer=PILGL.TARGET_LAYER)
                 if show_debug:
-                    self.text_rowcol((row+0.8, col+0.0), strText=str(target), layer=PILGL.TARGET_LAYER)
+                    self.text_rowcol((row + 0.8, col + 0.0), strText=str(target), layer=PILGL.TARGET_LAYER)
 
             if binary_trans == 0:
                 if self.background_grid[col][row] <= 4:
@@ -607,7 +607,7 @@ class PILSVG(PILGL):
 
         if show_debug:
             print("Call text:")
-            self.text_rowcol((row+0.2, col+0.2,), str(agent_idx))
+            self.text_rowcol((row + 0.2, col + 0.2,), str(agent_idx))
 
     def set_cell_occupied(self, agent_idx, row, col):
         occupied_im = self.cell_occupied[agent_idx % len(self.cell_occupied)]
diff --git a/notebooks/simple_example1_env_from_tuple.ipynb b/notebooks/simple_example1_env_from_tuple.ipynb
index 3fd55bc8fafabdd57eb43e012c5d98b37c73d496..0fcfe26325e5778cbbfdf66591346972edb2406d 100644
--- a/notebooks/simple_example1_env_from_tuple.ipynb
+++ b/notebooks/simple_example1_env_from_tuple.ipynb
@@ -14,7 +14,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from flatland.envs.generators import rail_from_manual_specifications_generator\n",
+    "from flatland.envs.rail_generators import rail_from_manual_specifications_generator\n",
     "from flatland.envs.observations import TreeObsForRailEnv\n",
     "from flatland.envs.rail_env import RailEnv\n",
     "from flatland.utils.rendertools import RenderTool\n",
diff --git a/notebooks/simple_example2_generate_random_rail.ipynb b/notebooks/simple_example2_generate_random_rail.ipynb
index b9d4a96c02d4cb63392654025c6857c8b6764d1a..19b854ee15d8dd1e19361f58a552eba617b19b67 100644
--- a/notebooks/simple_example2_generate_random_rail.ipynb
+++ b/notebooks/simple_example2_generate_random_rail.ipynb
@@ -15,7 +15,7 @@
    "source": [
     "import random\n",
     "import numpy as np\n",
-    "from flatland.envs.generators import random_rail_generator\n",
+    "from flatland.envs.rail_generators import random_rail_generator\n",
     "from flatland.envs.observations import TreeObsForRailEnv\n",
     "from flatland.envs.rail_env import RailEnv\n",
     "from flatland.utils.rendertools import RenderTool\n",
diff --git a/notebooks/simple_example_3_manual_control.ipynb b/notebooks/simple_example_3_manual_control.ipynb
index 50f228055b320c15e0411c2b086254a1f4d4ceef..cb2b377765f375e8d77d8374fdcc3bb67ce06444 100644
--- a/notebooks/simple_example_3_manual_control.ipynb
+++ b/notebooks/simple_example_3_manual_control.ipynb
@@ -40,7 +40,7 @@
     "import random\n",
     "import numpy as np\n",
     "import time\n",
-    "from flatland.envs.generators import random_rail_generator\n",
+    "from flatland.envs.rail_generators import random_rail_generator\n",
     "from flatland.envs.observations import TreeObsForRailEnv\n",
     "from flatland.envs.rail_env import RailEnv\n",
     "from flatland.utils.rendertools import RenderTool"
diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py
index 62df397d34b06755465ca0c9f664b9117c87243f..e5e89f76428bb881d0f72aa60aada97ab02167a5 100644
--- a/tests/test_distance_map.py
+++ b/tests/test_distance_map.py
@@ -2,10 +2,11 @@ import numpy as np
 
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.generators import rail_from_grid_transition_map
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator
 
 
 def test_walker():
@@ -27,6 +28,7 @@ def test_walker():
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2,
                                                        predictor=ShortestPathPredictorForRailEnv(max_depth=10)),
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index 574705c49501415f37149bd4b3d870665bf06e60..c96e8db00fe721f42667aed4833d034a47f19156 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -5,10 +5,11 @@ import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.envs.agent_utils import EnvAgent
-from flatland.envs.generators import rail_from_grid_transition_map
 from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from flatland.utils.simple_rail import make_simple_rail
 
@@ -21,6 +22,7 @@ def test_global_obs():
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=1,
                   obs_builder_object=GlobalObsForRailEnv())
 
@@ -90,6 +92,7 @@ def test_reward_function_conflict(rendering=False):
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
@@ -168,6 +171,7 @@ def test_reward_function_waiting(rendering=False):
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index 98c276f894b51685ce0edf43f6bd1b1137d46eb0..09f7e5e67a15c55b5070ac8679e43ecc9a14b9da 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -5,10 +5,11 @@ import pprint
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.generators import rail_from_grid_transition_map
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
 
@@ -21,6 +22,7 @@ def test_dummy_predictor(rendering=False):
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
                   )
@@ -111,6 +113,7 @@ def test_shortest_path_predictor(rendering=False):
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
@@ -230,6 +233,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index 7ebbbb1461e24aae4c2319f51a9bb4abb2d3b25c..d5dc3ac7af4be6ebd8c5cbeaf705bb710d36d138 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -6,10 +6,11 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.agent_utils import EnvAgentStatic
-from flatland.envs.generators import complex_rail_generator
-from flatland.envs.generators import rail_from_grid_transition_map
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator
 
 """Tests for `flatland` package."""
 
@@ -26,6 +27,7 @@ def test_load_env():
 def test_save_load():
     env = RailEnv(width=10, height=10,
                   rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0),
+                  schedule_generator=complex_schedule_generator(),
                   number_of_agents=2)
     env.reset()
     agent_1_pos = env.agents_static[0].position
@@ -77,6 +79,7 @@ def test_rail_environment_single_agent():
     rail_env = RailEnv(width=3,
                        height=3,
                        rail_generator=rail_from_grid_transition_map(rail),
+                       schedule_generator=random_schedule_generator(),
                        number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
 
@@ -156,6 +159,7 @@ def test_dead_end():
     rail_env = RailEnv(width=rail_map.shape[1],
                        height=rail_map.shape[0],
                        rail_generator=rail_from_grid_transition_map(rail),
+                       schedule_generator=random_schedule_generator(),
                        number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
 
@@ -200,6 +204,7 @@ def test_dead_end():
     rail_env = RailEnv(width=rail_map.shape[1],
                        height=rail_map.shape[0],
                        rail_generator=rail_from_grid_transition_map(rail),
+                       schedule_generator=random_schedule_generator(),
                        number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
 
diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
similarity index 87%
rename from tests/test_flatland_env_sparse_rail_generator.py
rename to tests/test_flatland_envs_sparse_rail_generator.py
index d59e684575e9410b2859bb011ecb835a267b1c36..db7cac61f4cf3bec4a330694c1864ef7d82bd076 100644
--- a/tests/test_flatland_env_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -1,6 +1,7 @@
-from flatland.envs.generators import sparse_rail_generator
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
 
 
@@ -16,6 +17,7 @@ def test_sparse_rail_generator():
                                                        seed=5,  # Random seed
                                                        realistic_mode=False  # Ordered distribution of nodes
                                                        ),
+                  schedule_generator=sparse_schedule_generator(),
                   number_of_agents=10,
                   obs_builder_object=GlobalObsForRailEnv())
     # reset to initialize agents_static
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 67dcd25c0769e542fd9a03502c2a8c1b29333b2b..eaf782df3255ecfc6ebaa7078935f485497ed359 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -1,8 +1,9 @@
 import numpy as np
 
-from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
 
 
 class SingleAgentNavigationObs(TreeObsForRailEnv):
@@ -62,6 +63,7 @@ def test_malfunction_process():
                   height=20,
                   rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
                                                         seed=0),
+                  schedule_generator=complex_schedule_generator(),
                   number_of_agents=2,
                   obs_builder_object=SingleAgentNavigationObs(),
                   stochastic_data=stochastic_data)
diff --git a/tests/test_flatland_utils_rendertools.py b/tests/test_flatland_utils_rendertools.py
index ac5d7f4132b5c224206af3602b6c1341fe026d8b..8248c675995fc5c906e82d8650a5b619e7b038f2 100644
--- a/tests/test_flatland_utils_rendertools.py
+++ b/tests/test_flatland_utils_rendertools.py
@@ -11,9 +11,9 @@ from importlib_resources import path
 
 import flatland.utils.rendertools as rt
 import images.test
-from flatland.envs.generators import empty_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import empty_rail_generator
 
 
 def checkFrozenImage(oRT, sFileImage, resave=False):
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 47aadee73cfc0b45dc701fe914a586feb31b2597..8de36c81e4a13c0b7e7e5e556ad79234503ad31a 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -1,10 +1,12 @@
 import numpy as np
 
-from flatland.envs.generators import complex_rail_generator
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
 
 np.random.seed(1)
 
+
 # Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
 # Training on simple small tasks is the best way to get familiar with the environment
 #
@@ -46,6 +48,7 @@ def test_multi_speed_init():
                   height=50,
                   rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
                                                         seed=0),
+                  schedule_generator=complex_schedule_generator(),
                   number_of_agents=5)
     # Initialize the agent with the parameters corresponding to the environment and observation_builder
     agent = RandomAgent(218, 4)
diff --git a/tests/test_speed_classes.py b/tests/test_speed_classes.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff5ee56a308ce19559d079b716bde90ad65baf11
--- /dev/null
+++ b/tests/test_speed_classes.py
@@ -0,0 +1,36 @@
+"""Test speed initialization by a map of speeds and their corresponding ratios."""
+import numpy as np
+
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import speed_initialization_helper, complex_schedule_generator
+
+
+def test_speed_initialization_helper():
+    np.random.seed(1)
+    speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.3}
+    actual_speeds = speed_initialization_helper(10, speed_ratio_map)
+
+    # seed makes speed_initialization_helper deterministic -> check generated speeds.
+    assert actual_speeds == [2, 3, 1, 2, 1, 1, 1, 2, 2, 2]
+
+
+def test_rail_env_speed_intializer():
+    speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2}
+
+    env = RailEnv(width=50,
+                  height=50,
+                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
+                                                        seed=0),
+                  schedule_generator=complex_schedule_generator(),
+                  number_of_agents=10)
+    env.reset()
+    actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents))
+
+    expected_speed_set = set(speed_ratio_map.keys())
+
+    # check that the number of speeds generated is correct
+    assert len(actual_speeds) == env.get_num_agents()
+
+    # check that only the speeds defined are generated
+    assert all({(actual_speed in expected_speed_set) for actual_speed in actual_speeds})
diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index f97b071e6b33c099efa5af36766e159e57716443..610022cafe12fccb2cbbd5da57006e61c89faf28 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -3,11 +3,13 @@
 
 import numpy as np
 
-from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \
-    random_rail_generator, empty_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \
+    random_rail_generator, empty_rail_generator
+from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator, \
+    agents_from_file
 from flatland.utils.simple_rail import make_simple_rail
 
 
@@ -58,7 +60,8 @@ def test_complex_rail_generator():
     env = RailEnv(width=x_dim,
                   height=y_dim,
                   number_of_agents=n_agents,
-                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
+                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
+                  schedule_generator=complex_schedule_generator()
                   )
     assert env.get_num_agents() == 2
     assert env.rail.grid.shape == (y_dim, x_dim)
@@ -69,7 +72,8 @@ def test_complex_rail_generator():
     env = RailEnv(width=x_dim,
                   height=y_dim,
                   number_of_agents=n_agents,
-                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
+                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
+                  schedule_generator=complex_schedule_generator()
                   )
     assert env.get_num_agents() == 0
     assert env.rail.grid.shape == (y_dim, x_dim)
@@ -82,7 +86,8 @@ def test_complex_rail_generator():
     env = RailEnv(width=x_dim,
                   height=y_dim,
                   number_of_agents=n_agents,
-                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
+                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
+                  schedule_generator=complex_schedule_generator()
                   )
     assert env.get_num_agents() == n_agents
     assert env.rail.grid.shape == (y_dim, x_dim)
@@ -94,6 +99,7 @@ def test_rail_from_grid_transition_map():
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=n_agents
                   )
     nr_rail_elements = np.count_nonzero(env.rail.grid)
@@ -118,6 +124,7 @@ def tests_rail_from_file():
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
                   number_of_agents=3,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
@@ -130,6 +137,7 @@ def tests_rail_from_file():
     env = RailEnv(width=1,
                   height=1,
                   rail_generator=rail_from_file(file_name),
+                  schedule_generator=agents_from_file(file_name),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
@@ -151,6 +159,7 @@ def tests_rail_from_file():
     env2 = RailEnv(width=rail_map.shape[1],
                    height=rail_map.shape[0],
                    rail_generator=rail_from_grid_transition_map(rail),
+                   schedule_generator=random_schedule_generator(),
                    number_of_agents=3,
                    obs_builder_object=GlobalObsForRailEnv(),
                    )
@@ -164,6 +173,7 @@ def tests_rail_from_file():
     env2 = RailEnv(width=1,
                    height=1,
                    rail_generator=rail_from_file(file_name_2),
+                   schedule_generator=agents_from_file(file_name_2),
                    number_of_agents=1,
                    obs_builder_object=GlobalObsForRailEnv(),
                    )
@@ -180,6 +190,7 @@ def tests_rail_from_file():
     env3 = RailEnv(width=1,
                    height=1,
                    rail_generator=rail_from_file(file_name),
+                   schedule_generator=agents_from_file(file_name),
                    number_of_agents=1,
                    obs_builder_object=GlobalObsForRailEnv(),
                    )
@@ -197,6 +208,7 @@ def tests_rail_from_file():
     env4 = RailEnv(width=1,
                    height=1,
                    rail_generator=rail_from_file(file_name_2),
+                   schedule_generator=agents_from_file(file_name_2),
                    number_of_agents=1,
                    obs_builder_object=TreeObsForRailEnv(max_depth=2),
                    )