From 9eb5bd553fda6918a091d4381d3f9dfd2a2695a3 Mon Sep 17 00:00:00 2001 From: Dipam Chakraborty <dipamc77@gmail.com> Date: Tue, 7 Sep 2021 17:40:49 +0530 Subject: [PATCH] sample city positions without retries --- flatland/envs/rail_generators.py | 87 +++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 30 deletions(-) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 90dcfb36..9c65c6d6 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -218,7 +218,7 @@ class SparseRailGen(RailGen): 'city_orientations' : orientation of cities """ if np_random is None: - np_random = RandomState() + np_random = RandomState(self.seed) rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) @@ -240,6 +240,7 @@ class SparseRailGen(RailGen): # and reduce the number of cities to build to avoid problems max_feasible_cities = min(self.max_num_cities, ((height - 2) // (2 * (city_radius + 1))) * ((width - 2) // (2 * (city_radius + 1)))) + if max_feasible_cities < 2: # sys.exit("[ABORT] Cannot fit more than one city in this map, no feasible environment possible! Aborting.") raise ValueError("ERROR: Cannot fit more than one city in this map, no feasible environment possible!") @@ -252,7 +253,6 @@ class SparseRailGen(RailGen): else: city_positions = self._generate_random_city_positions(max_feasible_cities, city_radius, width, height, np_random=np_random) - # reduce num_cities if less were generated in random mode num_cities = len(city_positions) # If random generation failed just put the cities evenly @@ -261,12 +261,12 @@ class SparseRailGen(RailGen): city_positions = self._generate_evenly_distr_city_positions(max_feasible_cities, city_radius, width, height) num_cities = len(city_positions) - # Set up connection points for all cities inner_connection_points, outer_connection_points, city_orientations, city_cells = \ self._generate_city_connection_points( city_positions, city_radius, vector_field, rails_between_cities, rail_pairs_in_city, np_random=np_random) + # import pdb; pdb.set_trace() # Connect the cities through the connection points inter_city_lines = self._connect_cities(city_positions, outer_connection_points, city_cells, @@ -315,27 +315,52 @@ class SparseRailGen(RailGen): """ city_positions: IntVector2DArray = [] - for city_idx in range(num_cities): - too_close = True - tries = 0 - - while too_close: - row = city_radius + 1 + np_random.randint(height - 2 * (city_radius + 1)) - col = city_radius + 1 + np_random.randint(width - 2 * (city_radius + 1)) - too_close = False - # Check distance to cities - for city_pos in city_positions: - if self.__class__._are_cities_overlapping((row, col), city_pos, 2 * (city_radius + 1) + 1): - too_close = True - - if not too_close: - city_positions.append((row, col)) - - tries += 1 - if tries > 200: - warnings.warn( - "Could not set all required cities!") - break + + # We track a grid of allowed indexes that can be sampled from for creating a new city + # This removes the old sampling method of retrying a random sample on failure + allowed_grid = np.zeros((height, width), dtype=np.uint8) + city_radius_pad1 = city_radius + 1 + # Borders have to be not allowed from the start + # allowed_grid == 1 indicates locations that are allowed + allowed_grid[city_radius_pad1:-city_radius_pad1, city_radius_pad1:-city_radius_pad1] = 1 + # This tracks the actual city borders + city_grid = np.ones((height, width), dtype=np.uint8) + for _ in range(num_cities): + allowed_indexes = np.where(allowed_grid == 1) + num_allowed_points = len(allowed_indexes[0]) + if num_allowed_points == 0: + break + # Sample one of the allowed indexes + point_index = np_random.randint(num_allowed_points) + row = int(allowed_indexes[0][point_index]) + col = int(allowed_indexes[1][point_index]) + # # All points in the radius of the allowed point should be 1 + assert np.all(city_grid[row - city_radius_pad1 : row + city_radius_pad1 + 1, + col - city_radius_pad1 : col + city_radius_pad1 + 1]), \ + "Sampling Error, Cities overlap" + + # Need to block city radius and extra margin so that next sampling is correct + # Clipping handles the case for negative indexes being generated + row_start = max(0, row - 2 * city_radius_pad1) + col_start = max(0, col - 2 * city_radius_pad1) + row_end = row + 2 * city_radius_pad1 + 1 + col_end = col + 2 * city_radius_pad1 + 1 + + allowed_grid[row_start : row_end, col_start : col_end] = 0 + + # City grids is needed for redundant assertion check above + row_start = max(0, row - city_radius_pad1) + col_start = max(0, col - city_radius_pad1) + row_end = row + city_radius_pad1 + 1 + col_end = col + city_radius_pad1 + 1 + city_grid[row_start : row_end, col_start : col_end] = 0 + + city_positions.append((row, col)) + + created_cites = len(city_positions) + if created_cites < num_cities: + city_warning = f"Could not set all required cities! Created {created_cites}/{num_cities}" + warnings.warn(city_warning) return city_positions def _generate_evenly_distr_city_positions(self, num_cities: int, city_radius: int, width: int, height: int @@ -360,7 +385,6 @@ class SparseRailGen(RailGen): """ aspect_ratio = height / width - # Compute max numbe of possible cities per row and col. # Respect padding at edges of environment # Respect padding between cities @@ -529,13 +553,13 @@ class SparseRailGen(RailGen): grid4_directions = [Grid4TransitionsEnum.NORTH, Grid4TransitionsEnum.EAST, Grid4TransitionsEnum.SOUTH, Grid4TransitionsEnum.WEST] - + # import pdb; pdb.set_trace() for current_city_idx in np.arange(len(city_positions)): closest_neighbours = self._closest_neighbour_in_grid4_directions(current_city_idx, city_positions) for out_direction in grid4_directions: - + neighbour_idx = self.get_closest_neighbour_for_direction(closest_neighbours, out_direction) - + for city_out_connection_point in connection_points[current_city_idx][out_direction]: min_connection_dist = np.inf @@ -547,14 +571,17 @@ class SparseRailGen(RailGen): if tmp_dist < min_connection_dist: min_connection_dist = tmp_dist neighbour_connection_point = tmp_in_connection_point - new_line = connect_rail_in_grid_map(grid_map, city_out_connection_point, neighbour_connection_point, rail_trans, flip_start_node_trans=False, flip_end_node_trans=False, respect_transition_validity=False, avoid_rail=True, forbidden_cells=city_cells) + if len(new_line) == 0: + warnings.warn("[WARNING] Unable to connect requested stations") + elif new_line[-1] != neighbour_connection_point or new_line[0] != city_out_connection_point: + warnings.warn("[WARNING] Unable to connect requested stations") all_paths.extend(new_line) - + import pdb; pdb.set_trace() return all_paths def get_closest_neighbour_for_direction(self, closest_neighbours, out_direction): -- GitLab