diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index a25862dce70a32f21057f25875ec6baaecf2206f..0f215709ebdd96298e2fe576d0b2244dc699f5c7 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -6,6 +6,7 @@ from typing import Callable, Tuple, Optional, Dict, List, Any import msgpack import numpy as np +from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_direction, mirror from flatland.core.grid.grid_utils import distance_on_rail, direction_to_city, IntVector2DArray, IntVector2D, \ Vec2dOperations @@ -641,11 +642,15 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ city_cells.extend(_get_cells_in_city(city_positions[-1], city_radius)) return city_positions, city_cells - def _generate_city_connection_points(city_positions: IntVector2DArray, city_radius: int, rails_between_cities: int, rails_in_city: int = 2): - inner_connection_points = [] - outer_connection_points = [] - connection_info = [] - city_orientations = [] + def _generate_city_connection_points(city_positions: IntVector2DArray, city_radius: int, rails_between_cities: int, + rails_in_city: int = 2) -> (List[List[List[IntVector2D]]], + List[List[List[IntVector2D]]], + List[np.ndarray], + List[Grid4TransitionsEnum]): + inner_connection_points: List[List[List[IntVector2D]]] = [] + outer_connection_points: List[List[List[IntVector2D]]] = [] + connection_info: List[np.ndarray] = [] + city_orientations: List[Grid4TransitionsEnum] = [] for city_position in city_positions: # Chose the directions where close cities are situated @@ -669,8 +674,8 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ nr_of_connection_points = np.random.randint(3, rails_in_city + 1) for idx in connection_sides_idx: connections_per_direction[idx] = nr_of_connection_points - connection_points_coordinates_inner = [[] for i in range(4)] - connection_points_coordinates_outer = [[] for i in range(4)] + connection_points_coordinates_inner: List[List[IntVector2D]] = [[] for i in range(4)] + connection_points_coordinates_outer: List[List[IntVector2D]] = [[] for i in range(4)] number_of_out_rails = np.random.randint(1, min(rails_between_cities, nr_of_connection_points) + 1) start_idx = int((nr_of_connection_points - number_of_out_rails) / 2) for direction in range(4): @@ -699,7 +704,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ return inner_connection_points, outer_connection_points, connection_info, city_orientations def _connect_cities(city_positions: IntVector2DArray, connection_points, city_cells: IntVector2DArray, - rail_trans, grid_map): + rail_trans, grid_map: GridTransitionMap): """ Function to connect the different cities through their connection points :param city_positions: Positions of city centers @@ -739,7 +744,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ return all_paths def _build_inner_cities(city_positions, inner_connection_points, outer_connection_points, rail_trans, - grid_map): + grid_map: GridTransitionMap): """ Builds inner city tracks. This current version connects all incoming connections to all outgoing connections :param city_positions: Positions of the cities @@ -831,7 +836,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ agent_start_targets_cities.append((start_city, target_city, city_orientation[start_city])) return agent_start_targets_cities, num_agents - def _fix_transitions(city_cells, inter_city_lines, grid_map): + def _fix_transitions(city_cells, inter_city_lines, grid_map: GridTransitionMap): """ Function to fix all transition elements in environment """