Skip to content
Snippets Groups Projects
Commit 4928e906 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

updated number of node selection algorithm

parent 66b7f776
No related branches found
No related tags found
No related merge requests found
......@@ -869,7 +869,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
available_intersections = np.arange(num_cities, num_cities + num_intersections)
current_node = 0
node_stack = [current_node]
allowed_connections = num_neighb
while len(node_stack) > 0:
current_node = node_stack[0]
delete_idx = np.where(available_nodes_full == current_node)
......@@ -883,7 +883,6 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
elif len(available_intersections) > 0:
available_nodes = available_cities
delete_idx = np.where(available_intersections == current_node)
available_intersections = np.delete(available_intersections, delete_idx, 0)
else:
available_nodes = available_nodes_full
......@@ -894,12 +893,15 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
available_nodes = available_nodes[np.argsort(node_dist)]
# Set number of neighboring nodes
if len(available_nodes) >= num_neighb:
connected_neighb_idx = available_nodes[
0:num_neighb] # np.random.choice(available_nodes, num_neighb, replace=False)
print(current_node, allowed_connections)
if len(available_nodes) >= allowed_connections:
connected_neighb_idx = available_nodes[:allowed_connections]
else:
connected_neighb_idx = available_nodes
if current_node == 0:
allowed_connections -= 1
# Connect to the neighboring nodes
for neighb in connected_neighb_idx:
if neighb not in node_stack:
......
......@@ -25,12 +25,12 @@ def test_realistic_rail_generator():
def test_sparse_rail_generator():
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=5, # Number of cities in map
num_intersections=2, # Number of interesections in map
num_trainstations=10, # Number of possible start/targets on map
rail_generator=sparse_rail_generator(num_cities=2, # Number of cities in map
num_intersections=3, # Number of interesections in map
num_trainstations=5, # Number of possible start/targets on map
min_node_dist=10, # Minimal distance of nodes
node_radius=2, # Proximity of stations to city center
num_neighb=1, # Number of connections to other cities
num_neighb=3, # Number of connections to other cities
seed=5, # Random seed
),
number_of_agents=0,
......@@ -38,4 +38,4 @@ def test_sparse_rail_generator():
# reset to initialize agents_static
env_renderer = RenderTool(env, gl="PILSVG", )
env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
time.sleep(19)
time.sleep(10)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment