diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py
index 515d6c1b0469b7fbd9bad8cd82a40db7766f6219..f6bd2bda9d4c0b8ebf0658759efa367ceeb0a098 100644
--- a/examples/custom_railmap_example.py
+++ b/examples/custom_railmap_example.py
@@ -1,9 +1,12 @@
 import random
+from typing import Any
 
 import numpy as np
 
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.agent_generators import AgentGenerator, AgentGeneratorProduct
+from flatland.envs.generators import RailGenerator, RailGeneratorProduct
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
 
@@ -11,20 +14,28 @@ random.seed(100)
 np.random.seed(100)
 
 
-def custom_rail_generator():
-    def generator(width, height, num_agents=0, num_resets=0):
+def custom_rail_generator() -> RailGenerator:
+    def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct:
         rail_trans = RailEnvTransitions()
         grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
         rail_array = grid_map.grid
         rail_array.fill(0)
         new_tran = rail_trans.set_transition(1, 1, 1, 1)
         print(new_tran)
+        rail_array[0, 0] = new_tran
+        rail_array[0, 1] = new_tran
+        return grid_map, None
+
+    return generator
+
+
+def custom_agent_generator() -> AgentGenerator:
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> AgentGeneratorProduct:
         agents_positions = []
         agents_direction = []
         agents_target = []
-        rail_array[0, 0] = new_tran
-        rail_array[0, 1] = new_tran
-        return grid_map, agents_positions, agents_direction, agents_target
+        speeds = []
+        return agents_positions, agents_direction, agents_target, speeds
 
     return generator
 
diff --git a/flatland/envs/agent_generators.py b/flatland/envs/agent_generators.py
index 1f769b7d65918ab1cb541f80dd2c60ac43526774..c03511bc186bb41554eecb8c1c62b73482012ed0 100644
--- a/flatland/envs/agent_generators.py
+++ b/flatland/envs/agent_generators.py
@@ -73,7 +73,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(speed_ratio_map: Mapping[float, float] =
         initial positions, directions, targets speeds
     """
 
-    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> AgentGeneratorProduct:
         def _path_exists(rail, start, direction, end):
             # BFS - Check if a path exists between the 2 nodes
 
@@ -165,7 +165,7 @@ def agents_from_file(filename) -> AgentGenerator:
         initial positions, directions, targets speeds
     """
 
-    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> AgentGeneratorProduct:
         with open(filename, "rb") as file_in:
             load_data = file_in.read()
         data = msgpack.unpackb(load_data, use_list=False)
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 380bf37f8de6b926e95bd1f805b71a21ea0b84e2..5e97f1588b38d5e054607d1ba07d7973c212a2f4 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -10,7 +10,8 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.grid4_generators_utils import connect_rail
 
-RailGenerator = Callable[[int, int, int, int], Tuple[GridTransitionMap, Optional[Any]]]
+RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]]
+RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
 
 
 def empty_rail_generator() -> RailGenerator:
@@ -19,13 +20,13 @@ def empty_rail_generator() -> RailGenerator:
     Primarily used by the editor
     """
 
-    def generator(width, height, num_agents=0, num_resets=0):
+    def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct:
         rail_trans = RailEnvTransitions()
         grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
         rail_array = grid_map.grid
         rail_array.fill(0)
 
-        return [grid_map, None]
+        return grid_map, None
 
     return generator
 
@@ -249,8 +250,8 @@ def rail_from_grid_transition_map(rail_map) -> RailGenerator:
         Generator function that always returns the given `rail_map' object.
     """
 
-    def generator(width, height, num_agents, num_resets=0):
-        return [rail_map, None]
+    def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
+        return rail_map, None
 
     return generator
 
@@ -287,7 +288,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener
         The matrix with the correct 16-bit bitmaps for each cell.
     """
 
-    def generator(width, height, num_agents, num_resets=0):
+    def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
         t_utils = RailEnvTransitions()
 
         transition_probability = cell_type_relative_proportion
@@ -519,6 +520,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener
         return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
         return_rail.grid = tmp_rail
 
-        return [return_rail, None]
+        return return_rail, None
 
     return generator