diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 0b3f3eb02dc77efc03740f97888dbfc9cabfa4e7..087c6299966f5d1e04e6d21809afbd73999be6c1 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -995,20 +995,28 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 nodes_ratio = height / width nodes_per_row = int(np.ceil(np.sqrt(tot_num_node * nodes_ratio))) nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row)) - x_positions = np.linspace(2, height - 2, nodes_per_row, dtype=int) - y_positions = np.linspace(2, width - 2, nodes_per_col, dtype=int) + x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) + y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) for node_idx in range(num_cities + num_intersections): to_close = True tries = 0 if not realistic_mode: while to_close: - x_tmp = 1 + np.random.randint(height - 2) - y_tmp = 1 + np.random.randint(width - 2) + x_tmp = node_radius + np.random.randint(height - node_radius) + y_tmp = node_radius + np.random.randint(width - node_radius) to_close = False - for node_pos in node_positions: + + # Check distance to cities + for node_pos in city_positions: + if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: + to_close = True + + # CHeck distance to intersections + for node_pos in intersection_positions: if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: to_close = True + if not to_close: node_positions.append((x_tmp, y_tmp)) if node_idx < num_cities: @@ -1027,30 +1035,39 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 else: intersection_positions.append((x_tmp, y_tmp)) - if realistic_mode: - node_positions = city_positions + intersection_positions + node_positions = city_positions + intersection_positions # Chose node connection + # Set up list of available nodes to connect to available_nodes_full = np.arange(num_cities + num_intersections) available_cities = np.arange(num_cities) available_intersections = np.arange(num_cities, num_cities + num_intersections) - current_node = 0 + + # Keep track of number of incoming connection + incoming_connections = np.zeros(num_intersections + num_cities) + + # Start at some node + current_node = np.random.randint(len(available_nodes_full)) node_stack = [current_node] allowed_connections = num_neighb while len(node_stack) > 0: current_node = node_stack[0] delete_idx = np.where(available_nodes_full == current_node) available_nodes_full = np.delete(available_nodes_full, delete_idx, 0) - + filled_nodes = np.where(incoming_connections >= num_neighb) + # Priority city to intersection connections if current_node < num_cities and len(available_intersections) > 0: available_nodes = available_intersections delete_idx = np.where(available_cities == current_node) - available_cities = np.delete(available_cities, delete_idx, 0) + + # Priority intersection to city connections elif current_node >= num_cities and len(available_cities) > 0: available_nodes = available_cities delete_idx = np.where(available_intersections == current_node) available_intersections = np.delete(available_intersections, delete_idx, 0) + + # If no options possible connect to whatever node is still available else: available_nodes = available_nodes_full @@ -1074,6 +1091,8 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 if neighb not in node_stack: node_stack.append(neighb) connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb]) + incoming_connections[neighb] += 1 + incoming_connections[current_node] += 1 node_stack.pop(0) # Place train stations close to the node