diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 9d1bcfeca0ab06932d164fdbeea5fa69103a0888..4463b1e3a0a457b895cbd8253a7ea87ae2633318 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -869,7 +869,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation available_intersections = np.arange(num_cities, num_cities + num_intersections) current_node = 0 node_stack = [current_node] - + allowed_connections = num_neighb while len(node_stack) > 0: current_node = node_stack[0] delete_idx = np.where(available_nodes_full == current_node) @@ -883,7 +883,6 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation elif len(available_intersections) > 0: available_nodes = available_cities delete_idx = np.where(available_intersections == current_node) - available_intersections = np.delete(available_intersections, delete_idx, 0) else: available_nodes = available_nodes_full @@ -894,12 +893,15 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation available_nodes = available_nodes[np.argsort(node_dist)] # Set number of neighboring nodes - if len(available_nodes) >= num_neighb: - connected_neighb_idx = available_nodes[ - 0:num_neighb] # np.random.choice(available_nodes, num_neighb, replace=False) + + print(current_node, allowed_connections) + if len(available_nodes) >= allowed_connections: + connected_neighb_idx = available_nodes[:allowed_connections] else: connected_neighb_idx = available_nodes + if current_node == 0: + allowed_connections -= 1 # Connect to the neighboring nodes for neighb in connected_neighb_idx: if neighb not in node_stack: diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py index 48d4c5776c65fb3d14c174c0f7ddf14f5ff9a76a..f49893aeb78cbb09c45993738b0466efe442e0db 100644 --- a/tests/test_flatland_env_sparse_rail_generator.py +++ b/tests/test_flatland_env_sparse_rail_generator.py @@ -25,12 +25,12 @@ def test_realistic_rail_generator(): def test_sparse_rail_generator(): env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=5, # Number of cities in map - num_intersections=2, # Number of interesections in map - num_trainstations=10, # Number of possible start/targets on map + rail_generator=sparse_rail_generator(num_cities=2, # Number of cities in map + num_intersections=3, # Number of interesections in map + num_trainstations=5, # Number of possible start/targets on map min_node_dist=10, # Minimal distance of nodes node_radius=2, # Proximity of stations to city center - num_neighb=1, # Number of connections to other cities + num_neighb=3, # Number of connections to other cities seed=5, # Random seed ), number_of_agents=0, @@ -38,4 +38,4 @@ def test_sparse_rail_generator(): # reset to initialize agents_static env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer.render_env(show=True, show_observations=True, show_predictions=False) - time.sleep(19) + time.sleep(10)