diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 574a3ee80cb735a625b6685955de8f3417997b38..fa18298c4eded5b7ba097e2e10fa73e6ef705e0a 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -581,7 +581,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
 
         # Connect the cities through the connection points
         city_connection_time = time.time()
-        inter_city_lines = _connect_cities(city_positions, outer_connection_points, connection_info, city_cells,
+        inter_city_lines = _connect_cities(city_positions, outer_connection_points, city_cells,
                                            rail_trans, grid_map)
         if DEBUG_PRINT_TIMING:
             print("City connection time", time.time() - city_connection_time)
@@ -622,59 +622,59 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         }}
 
     def _generate_random_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (List[Tuple[int, int]], List[Tuple[int, int]]):
-        node_positions: List[Tuple[int, int]] = []
+        city_positions: List[Tuple[int, int]] = []
         city_cells: List[Tuple[int, int]] = []
-        for node_idx in range(num_cities):
-            to_close = True
+        for city_idx in range(num_cities):
+            too_close = True
             tries = 0
 
-            while to_close:
-                x_tmp = city_radius + 1 + np.random.randint(height - 2 * (city_radius + 1))
-                y_tmp = city_radius + 1 + np.random.randint(width - 2 * (city_radius + 1))
-                to_close = False
+            while too_close:
+                row = city_radius + 1 + np.random.randint(height - 2 * (city_radius + 1))
+                col = city_radius + 1 + np.random.randint(width - 2 * (city_radius + 1))
+                too_close = False
                 # Check distance to nodes
-                for node_pos in node_positions:
-                    if _are_cities_overlapping((x_tmp, y_tmp), node_pos, 2 * (city_radius + 1) + 1):
-                        to_close = True
+                for node_pos in city_positions:
+                    if _are_cities_overlapping((row, col), node_pos, 2 * (city_radius + 1) + 1):
+                        too_close = True
 
-                if not to_close:
-                    node_positions.append((x_tmp, y_tmp))
-                    city_cells.extend(_get_cells_in_city(node_positions[-1], city_radius))
+                if not too_close:
+                    city_positions.append((row, col))
+                    city_cells.extend(_get_cells_in_city(city_positions[-1], city_radius))
 
                 tries += 1
                 if tries > 200:
                     warnings.warn(
-                        "Could not only set {} nodes after {} tries, although {} of nodes required to be generated!".format(
-                            len(node_positions),
+                        "Could not only set {} cities after {} tries, although {} of cities required to be generated!".format(
+                            len(city_positions),
                             tries, num_cities))
                     break
-        return node_positions, city_cells
+        return city_positions, city_cells
 
     def _generate_evenly_distr_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (List[Tuple[int, int]], List[Tuple[int, int]]):
-        nodes_ratio = height / width
-        nodes_per_row = int(np.ceil(np.sqrt(num_cities * nodes_ratio)))
-        nodes_per_col = int(np.ceil(num_cities / nodes_per_row))
-        x_positions = np.linspace(city_radius + 1, height - city_radius - 2, nodes_per_row, dtype=int)
-        y_positions = np.linspace(city_radius + 1, width - city_radius - 2, nodes_per_col, dtype=int)
-        node_positions = []
+        aspect_ratio = height / width
+        cities_per_row = int(np.ceil(np.sqrt(num_cities * aspect_ratio)))
+        cities_per_col = int(np.ceil(num_cities / cities_per_row))
+        row_positions = np.linspace(city_radius + 1, height - city_radius - 2, cities_per_row, dtype=int)
+        col_positions = np.linspace(city_radius + 1, width - city_radius - 2, cities_per_col, dtype=int)
+        city_positions = []
         city_cells = []
-        for node_idx in range(num_cities):
-            x_tmp = x_positions[node_idx % nodes_per_row]
-            y_tmp = y_positions[node_idx // nodes_per_row]
-            node_positions.append((x_tmp, y_tmp))
-            city_cells.extend(_get_cells_in_city(node_positions[-1], city_radius))
-        return node_positions, city_cells
-
-    def _generate_node_connection_points(node_positions, node_size, max_inter_city_rails_allowed, tracks_in_city=2):
+        for city_idx in range(num_cities):
+            row = row_positions[city_idx % cities_per_row]
+            col = col_positions[city_idx // cities_per_row]
+            city_positions.append((row, col))
+            city_cells.extend(_get_cells_in_city(city_positions[-1], city_radius))
+        return city_positions, city_cells
+
+    def _generate_node_connection_points(city_positions: List[Tuple[int, int]], city_radius: int, rails_between_cities: int, rails_in_city: int = 2):
         inner_connection_points = []
         outer_connection_points = []
         connection_info = []
         city_orientations = []
-        for node_position in node_positions:
+        for node_position in city_positions:
 
             # Chose the directions where close cities are situated
             neighb_dist = []
-            for neighb_node in node_positions:
+            for neighb_node in city_positions:
                 neighb_dist.append(distance_on_rail(node_position, neighb_node, metric="Manhattan"))
             closest_neighb_idx = argsort(neighb_dist)
 
@@ -684,18 +684,18 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
             if grid_mode:
                 current_closest_direction = np.random.randint(4)
             else:
-                current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]])
+                current_closest_direction = direction_to_point(node_position, city_positions[closest_neighb_idx[idx]])
             connection_sides_idx.append(current_closest_direction)
             connection_sides_idx.append((current_closest_direction + 2) % 4)
             city_orientations.append(current_closest_direction)
             # set the number of tracks within a city, at least 2 tracks per city
             connections_per_direction = np.zeros(4, dtype=int)
-            nr_of_connection_points = np.random.randint(3, tracks_in_city + 1)
+            nr_of_connection_points = np.random.randint(3, rails_in_city + 1)
             for idx in connection_sides_idx:
                 connections_per_direction[idx] = nr_of_connection_points
             connection_points_coordinates_inner = [[] for i in range(4)]
             connection_points_coordinates_outer = [[] for i in range(4)]
-            number_of_out_rails = np.random.randint(1, min(max_inter_city_rails_allowed, nr_of_connection_points) + 1)
+            number_of_out_rails = np.random.randint(1, min(rails_between_cities, nr_of_connection_points) + 1)
             start_idx = int((nr_of_connection_points - number_of_out_rails) / 2)
             for direction in range(4):
                 connection_slots = np.arange(connections_per_direction[direction]) - int(
@@ -703,16 +703,16 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
                 for connection_idx in range(connections_per_direction[direction]):
                     if direction == 0:
                         tmp_coordinates = (
-                            node_position[0] - node_size, node_position[1] + connection_slots[connection_idx])
+                            node_position[0] - city_radius, node_position[1] + connection_slots[connection_idx])
                     if direction == 1:
                         tmp_coordinates = (
-                            node_position[0] + connection_slots[connection_idx], node_position[1] + node_size)
+                            node_position[0] + connection_slots[connection_idx], node_position[1] + city_radius)
                     if direction == 2:
                         tmp_coordinates = (
-                            node_position[0] + node_size, node_position[1] + connection_slots[connection_idx])
+                            node_position[0] + city_radius, node_position[1] + connection_slots[connection_idx])
                     if direction == 3:
                         tmp_coordinates = (
-                            node_position[0] + connection_slots[connection_idx], node_position[1] - node_size)
+                            node_position[0] + connection_slots[connection_idx], node_position[1] - city_radius)
                     connection_points_coordinates_inner[direction].append(tmp_coordinates)
                     if connection_idx in range(start_idx, start_idx + number_of_out_rails + 1):
                         connection_points_coordinates_outer[direction].append(tmp_coordinates)
@@ -722,23 +722,22 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
             connection_info.append(connections_per_direction)
         return inner_connection_points, outer_connection_points, connection_info, city_orientations
 
-    def _connect_cities(node_positions, connection_points, connection_info, city_cells,
+    def _connect_cities(city_positions: List[Tuple[int, int]], connection_points, city_cells: List[Tuple[int, int]],
                         rail_trans, grid_map):
         """
         Function to connect the different cities through their connection points
         :param city_positions: Positions of city centers
         :param connection_points: Boarder connection points of cities
-        :param connection_info: Number of connection points per direction NESW
         :param rail_trans: Transitions
         :param grid_map: Grid map
         :return:
         """
         all_paths = []
 
-        for current_node in np.arange(len(node_positions)):
-            neighbours = _closest_neigh_in_direction(current_node, node_positions)
+        for current_city_idx in np.arange(len(city_positions)):
+            neighbours = _closest_neighbour_in_direction(current_city_idx, city_positions)
             for out_direction in range(4):
-                for tmp_out_connection_point in connection_points[current_node][out_direction]:
+                for tmp_out_connection_point in connection_points[current_city_idx][out_direction]:
                     # This only needs to be checked when entering this loop
                     neighb_idx = neighbours[out_direction]
                     if neighb_idx is None:
@@ -879,10 +878,10 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         for cell in range(rails_to_fix_cnt):
             grid_map.fix_transitions((rails_to_fix[2 * cell], rails_to_fix[2 * cell + 1]))
 
-    def _closest_neigh_in_direction(current_node, node_positions):
+    def _closest_neighbour_in_direction(current_city_idx: int, node_positions: List[Tuple[int, int]]):
         """
         Returns indices of closest neighbours in every direction NESW
-        :param current_node: Index of node in city_positions list
+        :param current_city_idx: Index of node in city_positions list
         :param city_positions: list of all points being considered
         :return: list of index of closest neighbours in all directions
         """
@@ -890,11 +889,11 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         closest_neighb = [None for i in range(4)]
         for av_node in range(len(node_positions)):
             node_dist.append(
-                distance_on_rail(node_positions[current_node], node_positions[av_node], metric="Manhattan"))
+                distance_on_rail(node_positions[current_city_idx], node_positions[av_node], metric="Manhattan"))
         sorted_neighbours = np.argsort(node_dist)
         direction_set = 0
         for neighb in sorted_neighbours[1:]:
-            direction_to_neighb = direction_to_point(node_positions[current_node], node_positions[neighb])
+            direction_to_neighb = direction_to_point(node_positions[current_city_idx], node_positions[neighb])
             if closest_neighb[direction_to_neighb] == None:
                 closest_neighb[direction_to_neighb] = neighb
                 direction_set += 1
@@ -909,10 +908,16 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
 
     def _get_cells_in_city(center: Tuple[int, int], radius: int) -> List[Tuple[int, int]]:
         """
-        Function to return all cells within a city
-        :param center: center coordinates of city
-        :param radius: radius of city (it is a square)
-        :return: returns flat list of all cell coordinates in the city
+
+        Parameters
+        ----------
+        center center coordinates of city
+        radius radius of city (it is a square)
+
+        Returns
+        -------
+        flat list of all cell coordinates in the city
+
         """
         x_range = np.arange(center[0] - radius, center[0] + radius + 1)
         y_range = np.arange(center[1] - radius, center[1] + radius + 1)
@@ -923,23 +928,4 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
     def _are_cities_overlapping(center_1, center_2, radius):
         return np.abs(center_1[0] - center_2[0]) < radius and np.abs(center_1[1] - center_2[1]) < radius
 
-    def _track_number(city_position, city_orientation, position):
-        """
-        FUnction that tells you if you are on even or uneven track number
-        :param city_position:
-        :param city_orientation:
-        :param position:
-        :return:
-        """
-        if city_orientation % 2 == 0:
-            if city_position[1] - position[1] < 0:
-                return np.abs(city_position[1] - position[1]) % 2
-            else:
-                return (np.abs(city_position[1] - position[1]) + 1) % 2
-        else:
-            if city_position[0] - position[0] > 0:
-                return np.abs(city_position[0] - position[0]) % 2
-            else:
-                return (np.abs(city_position[0] - position[0]) + 1) % 2
-
     return generator