Commit 9eb5bd55 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

sample city positions without retries

parent 13c731a9
......@@ -218,7 +218,7 @@ class SparseRailGen(RailGen):
'city_orientations' : orientation of cities
"""
if np_random is None:
np_random = RandomState()
np_random = RandomState(self.seed)
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
......@@ -240,6 +240,7 @@ class SparseRailGen(RailGen):
# and reduce the number of cities to build to avoid problems
max_feasible_cities = min(self.max_num_cities,
((height - 2) // (2 * (city_radius + 1))) * ((width - 2) // (2 * (city_radius + 1))))
if max_feasible_cities < 2:
# sys.exit("[ABORT] Cannot fit more than one city in this map, no feasible environment possible! Aborting.")
raise ValueError("ERROR: Cannot fit more than one city in this map, no feasible environment possible!")
......@@ -252,7 +253,6 @@ class SparseRailGen(RailGen):
else:
city_positions = self._generate_random_city_positions(max_feasible_cities, city_radius, width, height,
np_random=np_random)
# reduce num_cities if less were generated in random mode
num_cities = len(city_positions)
# If random generation failed just put the cities evenly
......@@ -261,12 +261,12 @@ class SparseRailGen(RailGen):
city_positions = self._generate_evenly_distr_city_positions(max_feasible_cities, city_radius, width,
height)
num_cities = len(city_positions)
# Set up connection points for all cities
inner_connection_points, outer_connection_points, city_orientations, city_cells = \
self._generate_city_connection_points(
city_positions, city_radius, vector_field, rails_between_cities,
rail_pairs_in_city, np_random=np_random)
# import pdb; pdb.set_trace()
# Connect the cities through the connection points
inter_city_lines = self._connect_cities(city_positions, outer_connection_points, city_cells,
......@@ -315,27 +315,52 @@ class SparseRailGen(RailGen):
"""
city_positions: IntVector2DArray = []
for city_idx in range(num_cities):
too_close = True
tries = 0
while too_close:
row = city_radius + 1 + np_random.randint(height - 2 * (city_radius + 1))
col = city_radius + 1 + np_random.randint(width - 2 * (city_radius + 1))
too_close = False
# Check distance to cities
for city_pos in city_positions:
if self.__class__._are_cities_overlapping((row, col), city_pos, 2 * (city_radius + 1) + 1):
too_close = True
if not too_close:
city_positions.append((row, col))
tries += 1
if tries > 200:
warnings.warn(
"Could not set all required cities!")
break
# We track a grid of allowed indexes that can be sampled from for creating a new city
# This removes the old sampling method of retrying a random sample on failure
allowed_grid = np.zeros((height, width), dtype=np.uint8)
city_radius_pad1 = city_radius + 1
# Borders have to be not allowed from the start
# allowed_grid == 1 indicates locations that are allowed
allowed_grid[city_radius_pad1:-city_radius_pad1, city_radius_pad1:-city_radius_pad1] = 1
# This tracks the actual city borders
city_grid = np.ones((height, width), dtype=np.uint8)
for _ in range(num_cities):
allowed_indexes = np.where(allowed_grid == 1)
num_allowed_points = len(allowed_indexes[0])
if num_allowed_points == 0:
break
# Sample one of the allowed indexes
point_index = np_random.randint(num_allowed_points)
row = int(allowed_indexes[0][point_index])
col = int(allowed_indexes[1][point_index])
# # All points in the radius of the allowed point should be 1
assert np.all(city_grid[row - city_radius_pad1 : row + city_radius_pad1 + 1,
col - city_radius_pad1 : col + city_radius_pad1 + 1]), \
"Sampling Error, Cities overlap"
# Need to block city radius and extra margin so that next sampling is correct
# Clipping handles the case for negative indexes being generated
row_start = max(0, row - 2 * city_radius_pad1)
col_start = max(0, col - 2 * city_radius_pad1)
row_end = row + 2 * city_radius_pad1 + 1
col_end = col + 2 * city_radius_pad1 + 1
allowed_grid[row_start : row_end, col_start : col_end] = 0
# City grids is needed for redundant assertion check above
row_start = max(0, row - city_radius_pad1)
col_start = max(0, col - city_radius_pad1)
row_end = row + city_radius_pad1 + 1
col_end = col + city_radius_pad1 + 1
city_grid[row_start : row_end, col_start : col_end] = 0
city_positions.append((row, col))
created_cites = len(city_positions)
if created_cites < num_cities:
city_warning = f"Could not set all required cities! Created {created_cites}/{num_cities}"
warnings.warn(city_warning)
return city_positions
def _generate_evenly_distr_city_positions(self, num_cities: int, city_radius: int, width: int, height: int
......@@ -360,7 +385,6 @@ class SparseRailGen(RailGen):
"""
aspect_ratio = height / width
# Compute max numbe of possible cities per row and col.
# Respect padding at edges of environment
# Respect padding between cities
......@@ -529,13 +553,13 @@ class SparseRailGen(RailGen):
grid4_directions = [Grid4TransitionsEnum.NORTH, Grid4TransitionsEnum.EAST, Grid4TransitionsEnum.SOUTH,
Grid4TransitionsEnum.WEST]
# import pdb; pdb.set_trace()
for current_city_idx in np.arange(len(city_positions)):
closest_neighbours = self._closest_neighbour_in_grid4_directions(current_city_idx, city_positions)
for out_direction in grid4_directions:
neighbour_idx = self.get_closest_neighbour_for_direction(closest_neighbours, out_direction)
for city_out_connection_point in connection_points[current_city_idx][out_direction]:
min_connection_dist = np.inf
......@@ -547,14 +571,17 @@ class SparseRailGen(RailGen):
if tmp_dist < min_connection_dist:
min_connection_dist = tmp_dist
neighbour_connection_point = tmp_in_connection_point
new_line = connect_rail_in_grid_map(grid_map, city_out_connection_point, neighbour_connection_point,
rail_trans, flip_start_node_trans=False,
flip_end_node_trans=False, respect_transition_validity=False,
avoid_rail=True,
forbidden_cells=city_cells)
if len(new_line) == 0:
warnings.warn("[WARNING] Unable to connect requested stations")
elif new_line[-1] != neighbour_connection_point or new_line[0] != city_out_connection_point:
warnings.warn("[WARNING] Unable to connect requested stations")
all_paths.extend(new_line)
import pdb; pdb.set_trace()
return all_paths
def get_closest_neighbour_for_direction(self, closest_neighbours, out_direction):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment