From 4928e9063688efc3569deb29c9d907d4ef46a57e Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Sun, 18 Aug 2019 08:20:12 -0400 Subject: [PATCH] updated number of node selection algorithm --- flatland/envs/generators.py | 12 +++++++----- tests/test_flatland_env_sparse_rail_generator.py | 10 +++++----- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 9d1bcfec..4463b1e3 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -869,7 +869,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation available_intersections = np.arange(num_cities, num_cities + num_intersections) current_node = 0 node_stack = [current_node] - + allowed_connections = num_neighb while len(node_stack) > 0: current_node = node_stack[0] delete_idx = np.where(available_nodes_full == current_node) @@ -883,7 +883,6 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation elif len(available_intersections) > 0: available_nodes = available_cities delete_idx = np.where(available_intersections == current_node) - available_intersections = np.delete(available_intersections, delete_idx, 0) else: available_nodes = available_nodes_full @@ -894,12 +893,15 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation available_nodes = available_nodes[np.argsort(node_dist)] # Set number of neighboring nodes - if len(available_nodes) >= num_neighb: - connected_neighb_idx = available_nodes[ - 0:num_neighb] # np.random.choice(available_nodes, num_neighb, replace=False) + + print(current_node, allowed_connections) + if len(available_nodes) >= allowed_connections: + connected_neighb_idx = available_nodes[:allowed_connections] else: connected_neighb_idx = available_nodes + if current_node == 0: + allowed_connections -= 1 # Connect to the neighboring nodes for neighb in connected_neighb_idx: if neighb not in node_stack: diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py index 48d4c577..f49893ae 100644 --- a/tests/test_flatland_env_sparse_rail_generator.py +++ b/tests/test_flatland_env_sparse_rail_generator.py @@ -25,12 +25,12 @@ def test_realistic_rail_generator(): def test_sparse_rail_generator(): env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=5, # Number of cities in map - num_intersections=2, # Number of interesections in map - num_trainstations=10, # Number of possible start/targets on map + rail_generator=sparse_rail_generator(num_cities=2, # Number of cities in map + num_intersections=3, # Number of interesections in map + num_trainstations=5, # Number of possible start/targets on map min_node_dist=10, # Minimal distance of nodes node_radius=2, # Proximity of stations to city center - num_neighb=1, # Number of connections to other cities + num_neighb=3, # Number of connections to other cities seed=5, # Random seed ), number_of_agents=0, @@ -38,4 +38,4 @@ def test_sparse_rail_generator(): # reset to initialize agents_static env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer.render_env(show=True, show_observations=True, show_predictions=False) - time.sleep(19) + time.sleep(10) -- GitLab