Skip to content
Snippets Groups Projects
Commit 4570db2d authored by Erik Nygren's avatar Erik Nygren
Browse files

dixing level generator city distribution

parent ac9dae25
No related branches found
No related tags found
No related merge requests found
...@@ -30,18 +30,18 @@ speed_ration_map = {1.: 0.25, # Fast passenger train ...@@ -30,18 +30,18 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
env = RailEnv(width=50, env = RailEnv(width=50,
height=50, height=50,
rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map (where train stations are) rail_generator=sparse_rail_generator(num_cities=30, # Number of cities in map (where train stations are)
num_intersections=15, # Number of intersections (no start / target) num_intersections=5, # Number of intersections (no start / target)
num_trainstations=50, # Number of possible start/targets on map num_trainstations=20, # Number of possible start/targets on map
min_node_dist=3, # Minimal distance of nodes min_node_dist=3, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center node_radius=2, # Proximity of stations to city center
num_neighb=3, # Number of connections to other cities/intersections num_neighb=3, # Number of connections to other cities/intersections
seed=15, # Random seed seed=15, # Random seed
realistic_mode=True, realistic_mode=True,
enhance_intersection=True enhance_intersection=True
), ),
schedule_generator=sparse_schedule_generator(speed_ration_map), schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=20, number_of_agents=40,
stochastic_data=stochastic_data, # Malfunction data generator stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=TreeObservation) obs_builder_object=TreeObservation)
......
...@@ -573,9 +573,14 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -573,9 +573,14 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int)
y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int)
fraction = 0
city_fraction = num_cities / tot_num_node
step = np.gcd(num_intersections, num_cities) / tot_num_node
for node_idx in range(num_cities + num_intersections): for node_idx in range(num_cities + num_intersections):
to_close = True to_close = True
tries = 0 tries = 0
fraction = (fraction + step) % 1.
if not realistic_mode: if not realistic_mode:
while to_close: while to_close:
x_tmp = node_radius + np.random.randint(height - node_radius) x_tmp = node_radius + np.random.randint(height - node_radius)
...@@ -587,7 +592,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -587,7 +592,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
to_close = True to_close = True
# CHeck distance to intersections # Check distance to intersections
for node_pos in intersection_positions: for node_pos in intersection_positions:
if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
to_close = True to_close = True
...@@ -605,11 +610,11 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -605,11 +610,11 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
else: else:
x_tmp = x_positions[node_idx % nodes_per_row] x_tmp = x_positions[node_idx % nodes_per_row]
y_tmp = y_positions[node_idx // nodes_per_row] y_tmp = y_positions[node_idx // nodes_per_row]
if len(city_positions) < num_cities and (node_idx % (tot_num_node // num_cities)) == 0: if len(city_positions) < num_cities and fraction < city_fraction:
city_positions.append((x_tmp, y_tmp)) city_positions.append((x_tmp, y_tmp))
else: else:
intersection_positions.append((x_tmp, y_tmp)) intersection_positions.append((x_tmp, y_tmp))
print(len(city_positions))
node_positions = city_positions + intersection_positions node_positions = city_positions + intersection_positions
# Chose node connection # Chose node connection
...@@ -627,6 +632,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -627,6 +632,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
current_node = node_stack[0] current_node = node_stack[0]
delete_idx = np.where(available_nodes_full == current_node) delete_idx = np.where(available_nodes_full == current_node)
available_nodes_full = np.delete(available_nodes_full, delete_idx, 0) available_nodes_full = np.delete(available_nodes_full, delete_idx, 0)
# Priority city to intersection connections # Priority city to intersection connections
if current_node < num_cities and len(available_intersections) > 0: if current_node < num_cities and len(available_intersections) > 0:
available_nodes = available_intersections available_nodes = available_intersections
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment