diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index a25862dce70a32f21057f25875ec6baaecf2206f..0f215709ebdd96298e2fe576d0b2244dc699f5c7 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -6,6 +6,7 @@ from typing import Callable, Tuple, Optional, Dict, List, Any
 import msgpack
 import numpy as np
 
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_direction, mirror
 from flatland.core.grid.grid_utils import distance_on_rail, direction_to_city, IntVector2DArray, IntVector2D, \
     Vec2dOperations
@@ -641,11 +642,15 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
             city_cells.extend(_get_cells_in_city(city_positions[-1], city_radius))
         return city_positions, city_cells
 
-    def _generate_city_connection_points(city_positions: IntVector2DArray, city_radius: int, rails_between_cities: int, rails_in_city: int = 2):
-        inner_connection_points = []
-        outer_connection_points = []
-        connection_info = []
-        city_orientations = []
+    def _generate_city_connection_points(city_positions: IntVector2DArray, city_radius: int, rails_between_cities: int,
+                                         rails_in_city: int = 2) -> (List[List[List[IntVector2D]]],
+                                                                     List[List[List[IntVector2D]]],
+                                                                     List[np.ndarray],
+                                                                     List[Grid4TransitionsEnum]):
+        inner_connection_points: List[List[List[IntVector2D]]] = []
+        outer_connection_points: List[List[List[IntVector2D]]] = []
+        connection_info: List[np.ndarray] = []
+        city_orientations: List[Grid4TransitionsEnum] = []
         for city_position in city_positions:
 
             # Chose the directions where close cities are situated
@@ -669,8 +674,8 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
             nr_of_connection_points = np.random.randint(3, rails_in_city + 1)
             for idx in connection_sides_idx:
                 connections_per_direction[idx] = nr_of_connection_points
-            connection_points_coordinates_inner = [[] for i in range(4)]
-            connection_points_coordinates_outer = [[] for i in range(4)]
+            connection_points_coordinates_inner: List[List[IntVector2D]] = [[] for i in range(4)]
+            connection_points_coordinates_outer: List[List[IntVector2D]] = [[] for i in range(4)]
             number_of_out_rails = np.random.randint(1, min(rails_between_cities, nr_of_connection_points) + 1)
             start_idx = int((nr_of_connection_points - number_of_out_rails) / 2)
             for direction in range(4):
@@ -699,7 +704,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         return inner_connection_points, outer_connection_points, connection_info, city_orientations
 
     def _connect_cities(city_positions: IntVector2DArray, connection_points, city_cells: IntVector2DArray,
-                        rail_trans, grid_map):
+                        rail_trans, grid_map: GridTransitionMap):
         """
         Function to connect the different cities through their connection points
         :param city_positions: Positions of city centers
@@ -739,7 +744,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         return all_paths
 
     def _build_inner_cities(city_positions, inner_connection_points, outer_connection_points, rail_trans,
-                            grid_map):
+                            grid_map: GridTransitionMap):
         """
         Builds inner city tracks. This current version connects all incoming connections to all outgoing connections
         :param city_positions: Positions of the cities
@@ -831,7 +836,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
             agent_start_targets_cities.append((start_city, target_city, city_orientation[start_city]))
         return agent_start_targets_cities, num_agents
 
-    def _fix_transitions(city_cells, inter_city_lines, grid_map):
+    def _fix_transitions(city_cells, inter_city_lines, grid_map: GridTransitionMap):
         """
         Function to fix all transition elements in environment
         """