Skip to content
Snippets Groups Projects
Commit ee52af87 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

fixed grid mode. Now if parameter is set too high it will not build those cities

parent 622f787d
No related branches found
No related tags found
No related merge requests found
...@@ -30,9 +30,9 @@ speed_ration_map = {1.: 0.25, # Fast passenger train ...@@ -30,9 +30,9 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 3.: 0.25, # Slow commuter train 1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train 1. / 4.: 0.25} # Slow freight train
env = RailEnv(width=75, env = RailEnv(width=30,
height=75, height=30,
rail_generator=sparse_rail_generator(max_num_cities=50, rail_generator=sparse_rail_generator(max_num_cities=4,
# Number of cities in map (where train stations are) # Number of cities in map (where train stations are)
seed=1, # Random seed seed=1, # Random seed
grid_mode=True, grid_mode=True,
......
...@@ -649,9 +649,9 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ ...@@ -649,9 +649,9 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
vector_field) -> (IntVector2DArray, IntVector2DArray): vector_field) -> (IntVector2DArray, IntVector2DArray):
aspect_ratio = height / width aspect_ratio = height / width
cities_per_row = min(int(np.ceil(np.sqrt(num_cities * aspect_ratio))), cities_per_row = min(int(np.ceil(np.sqrt(num_cities * aspect_ratio))),
int((height - 2 * (city_radius + 1)) / (2 * city_radius + 1))) int((height - 2) / (2 * city_radius + 1)))
cities_per_col = min(int(np.ceil(num_cities / cities_per_row)), cities_per_col = min(int(np.ceil(num_cities / cities_per_row)),
int((width - 2 * (city_radius + 1)) / (2 * city_radius + 1))) int((width - 2) / (2 * city_radius + 1)))
num_build_cities = min(num_cities, cities_per_col * cities_per_row) num_build_cities = min(num_cities, cities_per_col * cities_per_row)
row_positions = np.linspace(city_radius + 1, height - 2 * (city_radius + 1), cities_per_row, dtype=int) row_positions = np.linspace(city_radius + 1, height - 2 * (city_radius + 1), cities_per_row, dtype=int)
col_positions = np.linspace(city_radius + 1, width - 2 * (city_radius + 1), cities_per_col, dtype=int) col_positions = np.linspace(city_radius + 1, width - 2 * (city_radius + 1), cities_per_col, dtype=int)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment