diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 574a3ee80cb735a625b6685955de8f3417997b38..fa18298c4eded5b7ba097e2e10fa73e6ef705e0a 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -581,7 +581,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ # Connect the cities through the connection points city_connection_time = time.time() - inter_city_lines = _connect_cities(city_positions, outer_connection_points, connection_info, city_cells, + inter_city_lines = _connect_cities(city_positions, outer_connection_points, city_cells, rail_trans, grid_map) if DEBUG_PRINT_TIMING: print("City connection time", time.time() - city_connection_time) @@ -622,59 +622,59 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ }} def _generate_random_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (List[Tuple[int, int]], List[Tuple[int, int]]): - node_positions: List[Tuple[int, int]] = [] + city_positions: List[Tuple[int, int]] = [] city_cells: List[Tuple[int, int]] = [] - for node_idx in range(num_cities): - to_close = True + for city_idx in range(num_cities): + too_close = True tries = 0 - while to_close: - x_tmp = city_radius + 1 + np.random.randint(height - 2 * (city_radius + 1)) - y_tmp = city_radius + 1 + np.random.randint(width - 2 * (city_radius + 1)) - to_close = False + while too_close: + row = city_radius + 1 + np.random.randint(height - 2 * (city_radius + 1)) + col = city_radius + 1 + np.random.randint(width - 2 * (city_radius + 1)) + too_close = False # Check distance to nodes - for node_pos in node_positions: - if _are_cities_overlapping((x_tmp, y_tmp), node_pos, 2 * (city_radius + 1) + 1): - to_close = True + for node_pos in city_positions: + if _are_cities_overlapping((row, col), node_pos, 2 * (city_radius + 1) + 1): + too_close = True - if not to_close: - node_positions.append((x_tmp, y_tmp)) - city_cells.extend(_get_cells_in_city(node_positions[-1], city_radius)) + if not too_close: + city_positions.append((row, col)) + city_cells.extend(_get_cells_in_city(city_positions[-1], city_radius)) tries += 1 if tries > 200: warnings.warn( - "Could not only set {} nodes after {} tries, although {} of nodes required to be generated!".format( - len(node_positions), + "Could not only set {} cities after {} tries, although {} of cities required to be generated!".format( + len(city_positions), tries, num_cities)) break - return node_positions, city_cells + return city_positions, city_cells def _generate_evenly_distr_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (List[Tuple[int, int]], List[Tuple[int, int]]): - nodes_ratio = height / width - nodes_per_row = int(np.ceil(np.sqrt(num_cities * nodes_ratio))) - nodes_per_col = int(np.ceil(num_cities / nodes_per_row)) - x_positions = np.linspace(city_radius + 1, height - city_radius - 2, nodes_per_row, dtype=int) - y_positions = np.linspace(city_radius + 1, width - city_radius - 2, nodes_per_col, dtype=int) - node_positions = [] + aspect_ratio = height / width + cities_per_row = int(np.ceil(np.sqrt(num_cities * aspect_ratio))) + cities_per_col = int(np.ceil(num_cities / cities_per_row)) + row_positions = np.linspace(city_radius + 1, height - city_radius - 2, cities_per_row, dtype=int) + col_positions = np.linspace(city_radius + 1, width - city_radius - 2, cities_per_col, dtype=int) + city_positions = [] city_cells = [] - for node_idx in range(num_cities): - x_tmp = x_positions[node_idx % nodes_per_row] - y_tmp = y_positions[node_idx // nodes_per_row] - node_positions.append((x_tmp, y_tmp)) - city_cells.extend(_get_cells_in_city(node_positions[-1], city_radius)) - return node_positions, city_cells - - def _generate_node_connection_points(node_positions, node_size, max_inter_city_rails_allowed, tracks_in_city=2): + for city_idx in range(num_cities): + row = row_positions[city_idx % cities_per_row] + col = col_positions[city_idx // cities_per_row] + city_positions.append((row, col)) + city_cells.extend(_get_cells_in_city(city_positions[-1], city_radius)) + return city_positions, city_cells + + def _generate_node_connection_points(city_positions: List[Tuple[int, int]], city_radius: int, rails_between_cities: int, rails_in_city: int = 2): inner_connection_points = [] outer_connection_points = [] connection_info = [] city_orientations = [] - for node_position in node_positions: + for node_position in city_positions: # Chose the directions where close cities are situated neighb_dist = [] - for neighb_node in node_positions: + for neighb_node in city_positions: neighb_dist.append(distance_on_rail(node_position, neighb_node, metric="Manhattan")) closest_neighb_idx = argsort(neighb_dist) @@ -684,18 +684,18 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ if grid_mode: current_closest_direction = np.random.randint(4) else: - current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]]) + current_closest_direction = direction_to_point(node_position, city_positions[closest_neighb_idx[idx]]) connection_sides_idx.append(current_closest_direction) connection_sides_idx.append((current_closest_direction + 2) % 4) city_orientations.append(current_closest_direction) # set the number of tracks within a city, at least 2 tracks per city connections_per_direction = np.zeros(4, dtype=int) - nr_of_connection_points = np.random.randint(3, tracks_in_city + 1) + nr_of_connection_points = np.random.randint(3, rails_in_city + 1) for idx in connection_sides_idx: connections_per_direction[idx] = nr_of_connection_points connection_points_coordinates_inner = [[] for i in range(4)] connection_points_coordinates_outer = [[] for i in range(4)] - number_of_out_rails = np.random.randint(1, min(max_inter_city_rails_allowed, nr_of_connection_points) + 1) + number_of_out_rails = np.random.randint(1, min(rails_between_cities, nr_of_connection_points) + 1) start_idx = int((nr_of_connection_points - number_of_out_rails) / 2) for direction in range(4): connection_slots = np.arange(connections_per_direction[direction]) - int( @@ -703,16 +703,16 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ for connection_idx in range(connections_per_direction[direction]): if direction == 0: tmp_coordinates = ( - node_position[0] - node_size, node_position[1] + connection_slots[connection_idx]) + node_position[0] - city_radius, node_position[1] + connection_slots[connection_idx]) if direction == 1: tmp_coordinates = ( - node_position[0] + connection_slots[connection_idx], node_position[1] + node_size) + node_position[0] + connection_slots[connection_idx], node_position[1] + city_radius) if direction == 2: tmp_coordinates = ( - node_position[0] + node_size, node_position[1] + connection_slots[connection_idx]) + node_position[0] + city_radius, node_position[1] + connection_slots[connection_idx]) if direction == 3: tmp_coordinates = ( - node_position[0] + connection_slots[connection_idx], node_position[1] - node_size) + node_position[0] + connection_slots[connection_idx], node_position[1] - city_radius) connection_points_coordinates_inner[direction].append(tmp_coordinates) if connection_idx in range(start_idx, start_idx + number_of_out_rails + 1): connection_points_coordinates_outer[direction].append(tmp_coordinates) @@ -722,23 +722,22 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ connection_info.append(connections_per_direction) return inner_connection_points, outer_connection_points, connection_info, city_orientations - def _connect_cities(node_positions, connection_points, connection_info, city_cells, + def _connect_cities(city_positions: List[Tuple[int, int]], connection_points, city_cells: List[Tuple[int, int]], rail_trans, grid_map): """ Function to connect the different cities through their connection points :param city_positions: Positions of city centers :param connection_points: Boarder connection points of cities - :param connection_info: Number of connection points per direction NESW :param rail_trans: Transitions :param grid_map: Grid map :return: """ all_paths = [] - for current_node in np.arange(len(node_positions)): - neighbours = _closest_neigh_in_direction(current_node, node_positions) + for current_city_idx in np.arange(len(city_positions)): + neighbours = _closest_neighbour_in_direction(current_city_idx, city_positions) for out_direction in range(4): - for tmp_out_connection_point in connection_points[current_node][out_direction]: + for tmp_out_connection_point in connection_points[current_city_idx][out_direction]: # This only needs to be checked when entering this loop neighb_idx = neighbours[out_direction] if neighb_idx is None: @@ -879,10 +878,10 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ for cell in range(rails_to_fix_cnt): grid_map.fix_transitions((rails_to_fix[2 * cell], rails_to_fix[2 * cell + 1])) - def _closest_neigh_in_direction(current_node, node_positions): + def _closest_neighbour_in_direction(current_city_idx: int, node_positions: List[Tuple[int, int]]): """ Returns indices of closest neighbours in every direction NESW - :param current_node: Index of node in city_positions list + :param current_city_idx: Index of node in city_positions list :param city_positions: list of all points being considered :return: list of index of closest neighbours in all directions """ @@ -890,11 +889,11 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ closest_neighb = [None for i in range(4)] for av_node in range(len(node_positions)): node_dist.append( - distance_on_rail(node_positions[current_node], node_positions[av_node], metric="Manhattan")) + distance_on_rail(node_positions[current_city_idx], node_positions[av_node], metric="Manhattan")) sorted_neighbours = np.argsort(node_dist) direction_set = 0 for neighb in sorted_neighbours[1:]: - direction_to_neighb = direction_to_point(node_positions[current_node], node_positions[neighb]) + direction_to_neighb = direction_to_point(node_positions[current_city_idx], node_positions[neighb]) if closest_neighb[direction_to_neighb] == None: closest_neighb[direction_to_neighb] = neighb direction_set += 1 @@ -909,10 +908,16 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ def _get_cells_in_city(center: Tuple[int, int], radius: int) -> List[Tuple[int, int]]: """ - Function to return all cells within a city - :param center: center coordinates of city - :param radius: radius of city (it is a square) - :return: returns flat list of all cell coordinates in the city + + Parameters + ---------- + center center coordinates of city + radius radius of city (it is a square) + + Returns + ------- + flat list of all cell coordinates in the city + """ x_range = np.arange(center[0] - radius, center[0] + radius + 1) y_range = np.arange(center[1] - radius, center[1] + radius + 1) @@ -923,23 +928,4 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ def _are_cities_overlapping(center_1, center_2, radius): return np.abs(center_1[0] - center_2[0]) < radius and np.abs(center_1[1] - center_2[1]) < radius - def _track_number(city_position, city_orientation, position): - """ - FUnction that tells you if you are on even or uneven track number - :param city_position: - :param city_orientation: - :param position: - :return: - """ - if city_orientation % 2 == 0: - if city_position[1] - position[1] < 0: - return np.abs(city_position[1] - position[1]) % 2 - else: - return (np.abs(city_position[1] - position[1]) + 1) % 2 - else: - if city_position[0] - position[0] > 0: - return np.abs(city_position[0] - position[0]) % 2 - else: - return (np.abs(city_position[0] - position[0]) + 1) % 2 - return generator