diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 355f5502992a34f8d58d4dbd80028eb4dd71cc48..79e0ac7d429f3d22982f4510fbce8868e547bbc5 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -1,3 +1,5 @@
+from typing import Mapping, Tuple, List, Callable
+
 import msgpack
 import numpy as np
 
@@ -27,7 +29,12 @@ def empty_rail_generator():
     return generator
 
 
-def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=99999, seed=0):
+def complex_rail_generator(nr_start_goal=1,
+                           nr_extra=100,
+                           min_dist=20,
+                           max_dist=99999,
+                           seed=0,
+                           speed_initializer: Callable[[int], List[float]] = None):
     """
     Parameters
     -------
@@ -35,6 +42,8 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
         The width (number of cells) of the grid to generate.
     height : int
         The height (number of cells) of the grid to generate.
+    speed_initializer : Callable[[int], List[float]]
+        Function that returns a list of speeds for the numer of agents given as argument.
 
     Returns
     -------
@@ -145,7 +154,11 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
         agents_target = [sg[1] for sg in start_goal[:num_agents]]
         agents_direction = start_dir[:num_agents]
 
-        return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+        if speed_initializer:
+            speeds = speed_initializer(num_agents)
+        else:
+            speeds = [1.0] * len(agents_position)
+        return grid_map, agents_position, agents_direction, agents_target, speeds
 
     return generator
 
@@ -538,3 +551,24 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
         return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
 
     return generator
+
+
+def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float]) -> List[float]:
+    """
+    Parameters
+    -------
+    nb_agents : int
+        The number of agents to generate a speed for
+    speed_ratio_map : Mapping[float,float]
+        A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1.
+
+    Returns
+    -------
+    List[float]
+        A list of size nb_agents of speeds with the corresponding probabilistic ratios.
+    """
+    nb_classes = len(speed_ratio_map.keys())
+    speed_ratio_map_as_list: List[Tuple[float, float]] = list(speed_ratio_map.items())
+    speed_ratios = list(map(lambda t: t[1], speed_ratio_map_as_list))
+    speeds = list(map(lambda t: t[0], speed_ratio_map_as_list))
+    return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios)))
diff --git a/tests/test_speed_classes.py b/tests/test_speed_classes.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ef600d949fb9ccfa1fe5c22d276163bbeb4e736
--- /dev/null
+++ b/tests/test_speed_classes.py
@@ -0,0 +1,37 @@
+"""Test speed initialization by a map of speeds and their corresponding ratios."""
+import numpy as np
+
+from flatland.envs.generators import speed_initialization_helper, complex_rail_generator
+from flatland.envs.rail_env import RailEnv
+
+
+def test_speed_initialization_helper():
+    np.random.seed(1)
+    speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.3}
+    actual_speeds = speed_initialization_helper(10, speed_ratio_map)
+
+    # seed makes speed_initialization_helper deterministic -> check generated speeds.
+    assert actual_speeds == [2, 3, 1, 2, 1, 1, 1, 2, 2, 2]
+
+
+def test_rail_env_speed_intializer():
+    speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2}
+
+    def my_speed_initializer(nb_agents):
+        return speed_initialization_helper(nb_agents, speed_ratio_map)
+
+    env = RailEnv(width=50,
+                  height=50,
+                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
+                                                        seed=0, speed_initializer=my_speed_initializer),
+                  number_of_agents=10)
+    env.reset()
+    actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents))
+
+    expected_speed_set = set(speed_ratio_map.keys())
+
+    # check that the number of speeds generated is correct
+    assert len(actual_speeds) == env.get_num_agents()
+
+    # check that only the speeds defined are generated
+    assert all({(actual_speed in expected_speed_set) for actual_speed in actual_speeds})