From 44635204283dd5a026d067d1168d50f8742ade8c Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Wed, 25 Sep 2019 14:27:26 -0400 Subject: [PATCH] fixed bug in randomly placed cities --- examples/flatland_2_0_example.py | 4 ++-- flatland/envs/rail_generators.py | 31 ++++++++++++++++--------------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 339945f3..2334450c 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -35,11 +35,11 @@ env = RailEnv(width=50, rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are) num_intersections=0, # Number of intersections (no start / target) num_trainstations=15, # Number of possible start/targets on map - min_node_dist=3, # Minimal distance of nodes + min_node_dist=10, # Minimal distance of nodes node_radius=4, # Proximity of stations to city center num_neighb=2, # Number of connections to other cities/intersections seed=15, # Random seed - grid_mode=True, + grid_mode=False, enhance_intersection=False ), schedule_generator=sparse_schedule_generator(), diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index f3d7b889..c0721f62 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -604,26 +604,30 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # Start at some node current_node = np.random.randint(len(available_nodes_full)) node_stack = [current_node] + open_nodes = np.copy(available_nodes_full) allowed_connections = num_neighb - first_node = True i = 0 boarder_connections = set() - while len(node_stack) > 0: - current_node = node_stack[0] - delete_idx = np.where(available_nodes_full == current_node) - available_nodes_full = np.delete(available_nodes_full, delete_idx, 0) + while len(open_nodes) > 0: + if len(node_stack) > 0: + current_node = node_stack[0] + else: + current_node = np.random.choice(open_nodes) + node_stack.append(current_node) + delete_idx = np.where(open_nodes == current_node) + open_nodes = np.delete(open_nodes, delete_idx, 0) # Priority city to intersection connections if current_node < _num_cities and len(available_intersections) > 0: available_nodes = available_intersections delete_idx = np.where(available_cities == current_node) - available_cities = np.delete(available_cities, delete_idx, 0) + # available_cities = np.delete(available_cities, delete_idx, 0) # Priority intersection to city connections elif current_node >= _num_cities and len(available_cities) > 0: available_nodes = available_cities delete_idx = np.where(available_intersections == current_node) - available_intersections = np.delete(available_intersections, delete_idx, 0) + # available_intersections = np.delete(available_intersections, delete_idx, 0) # If no options possible connect to whatever node is still available else: @@ -637,18 +641,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # Set number of neighboring nodes if len(available_nodes) >= allowed_connections: - connected_neighb_idx = available_nodes[:allowed_connections] + connected_neighb_idx = available_nodes[1:allowed_connections + 1] else: connected_neighb_idx = available_nodes + print(current_node, connected_neighb_idx) - # Less connections for subsequent nodes - if first_node: - allowed_connections -= 1 - first_node = False # Connect to the neighboring nodes for neighb in connected_neighb_idx: - if neighb not in node_stack: + if neighb not in node_stack and neighb in open_nodes: node_stack.append(neighb) dist_from_center = distance_on_rail(node_positions[current_node], node_positions[neighb]) @@ -824,8 +825,8 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 tries = 0 while to_close: - x_tmp = node_radius + np.random.randint(height - node_radius - 1) - y_tmp = node_radius + np.random.randint(width - node_radius - 1) + x_tmp = node_radius + np.random.randint(height - 2 * node_radius - 1) + y_tmp = node_radius + np.random.randint(width - 2 * node_radius - 1) to_close = False # Check distance to cities -- GitLab