diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 8106585c098e54d275016b33ff6df4617d517e43..339945f3f886a4d53c4237b44bc68dd991b9154a 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -37,7 +37,7 @@ env = RailEnv(width=50, num_trainstations=15, # Number of possible start/targets on map min_node_dist=3, # Minimal distance of nodes node_radius=4, # Proximity of stations to city center - num_neighb=4, # Number of connections to other cities/intersections + num_neighb=2, # Number of connections to other cities/intersections seed=15, # Random seed grid_mode=True, enhance_intersection=False diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 8ba8b8638d6b8096bea1928d6008bc613c239760..6ad94e1fa49332cb18aa56a990225878f9582752 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -529,7 +529,8 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2, min_node_dist=20, node_radius=2, - num_neighb=3, grid_mode=False, enhance_intersection=False, seed=0) -> RailGenerator: + num_neighb=3, nr_inter_connections=2, grid_mode=False, enhance_intersection=False, + seed=0) -> RailGenerator: """ This is a level generator which generates complex sparse rail configurations @@ -599,7 +600,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 available_intersections = np.arange(_num_cities, nb_nodes) # Set up connection points - connection_points = _generate_node_connection_points(node_positions, node_radius, max_nr_connection_points=4) + connection_points = _generate_node_connection_points(node_positions, node_radius, max_nr_connection_points=8) # Start at some node current_node = np.random.randint(len(available_nodes_full)) node_stack = [current_node] @@ -651,22 +652,24 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 node_stack.append(neighb) dist_from_center = distance_on_rail(node_positions[current_node], node_positions[neighb]) + connection_distances = [] for tmp_out_connection_point in connection_points[current_node]: tmp_dist_to_node = distance_on_rail(tmp_out_connection_point, node_positions[neighb]) - # Check if this connection node is on the city side facing the neighbour - if tmp_dist_to_node < dist_from_center: - min_connection_dist = np.inf - - # Find closes connection point - for tmp_in_connection_point in connection_points[neighb]: - tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point) - if tmp_dist < min_connection_dist: - min_connection_dist = tmp_dist - neighb_connection_point = tmp_in_connection_point - i += 1 - connect_nodes(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point) - boarder_connections.add((tmp_out_connection_point, current_node)) - boarder_connections.add((neighb_connection_point, neighb)) + connection_distances.append(tmp_dist_to_node) + possible_connection_points = argsort(connection_distances) + for sort_idx in possible_connection_points[:nr_inter_connections]: + # Find closes connection point + tmp_out_connection_point = connection_points[current_node][sort_idx] + min_connection_dist = np.inf + for tmp_in_connection_point in connection_points[neighb]: + tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point) + if tmp_dist < min_connection_dist: + min_connection_dist = tmp_dist + neighb_connection_point = tmp_in_connection_point + i += 1 + connect_nodes(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point) + boarder_connections.add((tmp_out_connection_point, current_node)) + boarder_connections.add((neighb_connection_point, neighb)) node_stack.pop(0) @@ -876,24 +879,34 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 connection_point_vector = [connection_per_direction, connection_per_direction, connection_per_direction, n_connection_points - 3 * connection_per_direction] connection_points_coordinates = [] + random_connection_slots = False for direction in range(4): - rnd_points = np.random.choice(np.arange(-node_size, node_size), size=connection_point_vector[direction], - replace=False) + if random_connection_slots: + connection_slots = np.random.choice(np.arange(-node_size, node_size), + size=connection_point_vector[direction], + replace=False) + else: + connection_slots = np.arange(connection_point_vector[direction]) - int( + connection_point_vector[direction] / 2) + print(connection_slots) for connection_idx in range(connection_point_vector[direction]): if direction == 0: connection_points_coordinates.append( - (node_position[0] - node_size, node_position[1] + rnd_points[connection_idx])) + (node_position[0] - node_size, node_position[1] + connection_slots[connection_idx])) if direction == 1: connection_points_coordinates.append( - (node_position[0] + rnd_points[connection_idx], node_position[1] + node_size)) + (node_position[0] + connection_slots[connection_idx], node_position[1] + node_size)) if direction == 2: connection_points_coordinates.append( - (node_position[0] + node_size, node_position[1] + rnd_points[connection_idx])) + (node_position[0] + node_size, node_position[1] + connection_slots[connection_idx])) if direction == 3: connection_points_coordinates.append( - (node_position[0] + rnd_points[connection_idx], node_position[1] - node_size)) + (node_position[0] + connection_slots[connection_idx], node_position[1] - node_size)) connection_points.append(connection_points_coordinates) return connection_points + def argsort(seq): + # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python + return sorted(range(len(seq)), key=seq.__getitem__) return generator