Skip to content
Snippets Groups Projects
Commit 622f787d authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

fixed grid mode. Now if parameter is set too high it will not build those cities

parent 9d2d9575
No related branches found
No related tags found
No related merge requests found
......@@ -30,14 +30,14 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(max_num_cities=20,
env = RailEnv(width=75,
height=75,
rail_generator=sparse_rail_generator(max_num_cities=50,
# Number of cities in map (where train stations are)
seed=1, # Random seed
grid_mode=False,
grid_mode=True,
max_rails_between_cities=3,
max_rails_in_city=8,
max_rails_in_city=4,
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=50,
......
......@@ -648,13 +648,16 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
def _generate_evenly_distr_city_positions(num_cities: int, city_radius: int, width: int, height: int,
vector_field) -> (IntVector2DArray, IntVector2DArray):
aspect_ratio = height / width
cities_per_row = int(np.ceil(np.sqrt(num_cities * aspect_ratio)))
cities_per_col = int(np.ceil(num_cities / cities_per_row))
row_positions = np.linspace(city_radius + 1, height - city_radius - 2, cities_per_row, dtype=int)
col_positions = np.linspace(city_radius + 1, width - city_radius - 2, cities_per_col, dtype=int)
cities_per_row = min(int(np.ceil(np.sqrt(num_cities * aspect_ratio))),
int((height - 2 * (city_radius + 1)) / (2 * city_radius + 1)))
cities_per_col = min(int(np.ceil(num_cities / cities_per_row)),
int((width - 2 * (city_radius + 1)) / (2 * city_radius + 1)))
num_build_cities = min(num_cities, cities_per_col * cities_per_row)
row_positions = np.linspace(city_radius + 1, height - 2 * (city_radius + 1), cities_per_row, dtype=int)
col_positions = np.linspace(city_radius + 1, width - 2 * (city_radius + 1), cities_per_col, dtype=int)
city_positions = []
city_cells = []
for city_idx in range(num_cities):
for city_idx in range(num_build_cities):
row = row_positions[city_idx % cities_per_row]
col = col_positions[city_idx // cities_per_row]
city_positions.append((row, col))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment