From 014f2d747e1e1d8d3e31bd33d0894b7ab79c36bd Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Sun, 18 Aug 2019 07:49:43 -0400 Subject: [PATCH] priority to connect to intersection instead of other city. minor bug fixes in node connecting algorithm --- flatland/envs/generators.py | 6 ++++++ tests/test_flatland_env_sparse_rail_generator.py | 6 +++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 27508c22..b571a003 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -874,11 +874,16 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation current_node = node_stack[0] delete_idx = np.where(available_nodes_full == current_node) available_nodes_full = np.delete(available_nodes_full, delete_idx, 0) + if current_node < num_cities and len(available_intersections) > 0: available_nodes = available_intersections + delete_idx = np.where(available_cities == current_node) + available_cities = np.delete(available_cities, delete_idx, 0) elif len(available_intersections) > 0: available_nodes = available_cities + delete_idx = np.where(available_intersections == current_node) + available_intersections = np.delete(available_intersections, delete_idx, 0) else: available_nodes = available_nodes_full @@ -892,6 +897,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation if len(available_nodes) >= num_neighb: connected_neighb_idx = available_nodes[ 0:num_neighb] # np.random.choice(available_nodes, num_neighb, replace=False) + print(current_node, "-->", connected_neighb_idx) else: connected_neighb_idx = available_nodes diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py index b20754be..48d4c577 100644 --- a/tests/test_flatland_env_sparse_rail_generator.py +++ b/tests/test_flatland_env_sparse_rail_generator.py @@ -30,12 +30,12 @@ def test_sparse_rail_generator(): num_trainstations=10, # Number of possible start/targets on map min_node_dist=10, # Minimal distance of nodes node_radius=2, # Proximity of stations to city center - num_neighb=2, # Number of connections to other cities - seed=15, # Random seed + num_neighb=1, # Number of connections to other cities + seed=5, # Random seed ), number_of_agents=0, obs_builder_object=GlobalObsForRailEnv()) # reset to initialize agents_static env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer.render_env(show=True, show_observations=True, show_predictions=False) - time.sleep(5) + time.sleep(19) -- GitLab