From 743467833ea87162bce7289db86e779749013ec2 Mon Sep 17 00:00:00 2001 From: u229589 <christian.baumberger@sbb.ch> Date: Thu, 3 Oct 2019 08:48:20 +0200 Subject: [PATCH] refactor _city_cells to not use for loops --- flatland/envs/rail_generators.py | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index b132a5ca..e59c782e 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -641,7 +641,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ if not to_close: node_positions.append((x_tmp, y_tmp)) - city_cells.extend(_city_cells(node_positions[-1], city_radius)) + city_cells.extend(_get_cells_in_city(node_positions[-1], city_radius)) tries += 1 if tries > 200: @@ -664,7 +664,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ x_tmp = x_positions[node_idx % nodes_per_row] y_tmp = y_positions[node_idx // nodes_per_row] node_positions.append((x_tmp, y_tmp)) - city_cells.extend(_city_cells(node_positions[-1], city_radius)) + city_cells.extend(_get_cells_in_city(node_positions[-1], city_radius)) return node_positions, city_cells def _generate_node_connection_points(node_positions, node_size, max_inter_city_rails_allowed, tracks_in_city=2): @@ -781,8 +781,6 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ for current_city in range(len(node_positions)): all_outer_connection_points = [item for sublist in outer_connection_points[current_city] for item in sublist] - city_boarder = _city_boarder(node_positions[current_city], node_radius) - # This part only works if we have keep same number of connection points for both directions # Also only works with two connection direction at each city for i in range(4): @@ -911,30 +909,21 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python return sorted(range(len(seq)), key=seq.__getitem__) - def _city_cells(center, radius): + def _get_cells_in_city(center: Tuple[int], radius: int) -> List[Tuple[int, int]]: """ Function to return all cells within a city :param center: center coordinates of city :param radius: radius of city (it is a square) :return: returns flat list of all cell coordinates in the city """ - city_cells = [] - for x in range(-radius, radius + 1): - for y in range(-radius, radius + 1): - city_cells.append((center[0] + x, center[1] + y)) - - return city_cells - - def _city_boarder(center, radius): - city_boarder = [] - for x in range(-radius, radius + 1): - for y in range(-radius, radius + 1): - if abs(x) == radius or abs(y) == radius: - city_boarder.append((center[0] + x, center[1] + y)) - return city_boarder + x_range = np.arange(center[0] - radius, center[0] + radius + 1) + y_range = np.arange(center[1] - radius, center[1] + radius + 1) + x_values = np.repeat(x_range, len(y_range)) + y_values = np.tile(y_range, len(x_range)) + return list(zip(x_values, y_values)) def _city_overlap(center_1, center_2, radius): - return (np.abs(center_1[0] - center_2[0]) < radius and np.abs(center_1[1] - center_2[1]) < radius) + return np.abs(center_1[0] - center_2[0]) < radius and np.abs(center_1[1] - center_2[1]) < radius def _track_number(city_position, city_orientation, position): """ -- GitLab