Skip to content
Snippets Groups Projects
Commit dae7c424 authored by Erik Nygren's avatar Erik Nygren
Browse files

minor bugfixes

parent 4570db2d
No related branches found
No related tags found
No related merge requests found
...@@ -30,18 +30,18 @@ speed_ration_map = {1.: 0.25, # Fast passenger train ...@@ -30,18 +30,18 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
env = RailEnv(width=50, env = RailEnv(width=50,
height=50, height=50,
rail_generator=sparse_rail_generator(num_cities=30, # Number of cities in map (where train stations are) rail_generator=sparse_rail_generator(num_cities=20, # Number of cities in map (where train stations are)
num_intersections=5, # Number of intersections (no start / target) num_intersections=5, # Number of intersections (no start / target)
num_trainstations=20, # Number of possible start/targets on map num_trainstations=15, # Number of possible start/targets on map
min_node_dist=3, # Minimal distance of nodes min_node_dist=3, # Minimal distance of nodes
node_radius=2, # Proximity of stations to city center node_radius=2, # Proximity of stations to city center
num_neighb=3, # Number of connections to other cities/intersections num_neighb=4, # Number of connections to other cities/intersections
seed=15, # Random seed seed=15, # Random seed
realistic_mode=True, realistic_mode=True,
enhance_intersection=True enhance_intersection=True
), ),
schedule_generator=sparse_schedule_generator(speed_ration_map), schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=40, number_of_agents=10,
stochastic_data=stochastic_data, # Malfunction data generator stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=TreeObservation) obs_builder_object=TreeObservation)
......
...@@ -573,14 +573,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -573,14 +573,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int)
y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int)
fraction = 0 fraction = 0
city_fraction = num_cities / tot_num_node city_fraction = num_cities / tot_num_node
step = np.gcd(num_intersections, num_cities) / tot_num_node step = np.gcd(num_intersections, num_cities) / tot_num_node
for node_idx in range(num_cities + num_intersections): for node_idx in range(num_cities + num_intersections):
to_close = True to_close = True
tries = 0 tries = 0
fraction = (fraction + step) % 1.
if not realistic_mode: if not realistic_mode:
while to_close: while to_close:
x_tmp = node_radius + np.random.randint(height - node_radius) x_tmp = node_radius + np.random.randint(height - node_radius)
...@@ -608,13 +609,13 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -608,13 +609,13 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
warnings.warn("Could not set nodes, please change initial parameters!!!!") warnings.warn("Could not set nodes, please change initial parameters!!!!")
break break
else: else:
fraction = (fraction + step) % 1.
x_tmp = x_positions[node_idx % nodes_per_row] x_tmp = x_positions[node_idx % nodes_per_row]
y_tmp = y_positions[node_idx // nodes_per_row] y_tmp = y_positions[node_idx // nodes_per_row]
if len(city_positions) < num_cities and fraction < city_fraction: if len(city_positions) < num_cities and fraction < city_fraction:
city_positions.append((x_tmp, y_tmp)) city_positions.append((x_tmp, y_tmp))
else: else:
intersection_positions.append((x_tmp, y_tmp)) intersection_positions.append((x_tmp, y_tmp))
print(len(city_positions))
node_positions = city_positions + intersection_positions node_positions = city_positions + intersection_positions
# Chose node connection # Chose node connection
...@@ -634,13 +635,13 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -634,13 +635,13 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
available_nodes_full = np.delete(available_nodes_full, delete_idx, 0) available_nodes_full = np.delete(available_nodes_full, delete_idx, 0)
# Priority city to intersection connections # Priority city to intersection connections
if current_node < num_cities and len(available_intersections) > 0: if False and current_node < num_cities and len(available_intersections) > 0:
available_nodes = available_intersections available_nodes = available_intersections
delete_idx = np.where(available_cities == current_node) delete_idx = np.where(available_cities == current_node)
available_cities = np.delete(available_cities, delete_idx, 0) available_cities = np.delete(available_cities, delete_idx, 0)
# Priority intersection to city connections # Priority intersection to city connections
elif current_node >= num_cities and len(available_cities) > 0: elif False and current_node >= num_cities and len(available_cities) > 0:
available_nodes = available_cities available_nodes = available_cities
delete_idx = np.where(available_intersections == current_node) delete_idx = np.where(available_intersections == current_node)
available_intersections = np.delete(available_intersections, delete_idx, 0) available_intersections = np.delete(available_intersections, delete_idx, 0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment