diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index e901da5bf9331a4041d267e8692ab408624a1a11..28c78cb6332e9084634cd363df70355b29b38f64 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -37,11 +37,11 @@ env = RailEnv(width=50, node_radius=4, # Proximity of stations to city center seed=0, # Random seed grid_mode=True, - max_connection_points_per_side=2, - max_nr_connection_directions=2 + max_inter_city_rails=2, + tracks_in_city=4, ), schedule_generator=sparse_schedule_generator(), - number_of_agents=5, + number_of_agents=10, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=GlobalObsForRailEnv()) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index cf163e909941ee44b4da0af71c24534a7d1535d4..e381733faeb1c83ef72a3e19f10f17dce4d02966 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -533,8 +533,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, - grid_mode=False, max_connection_points_per_side=4, - max_nr_connection_directions=2, + grid_mode=False, max_inter_city_rails=4, tracks_in_city=4, seed=0) -> RailGenerator: """ This is a level generator which generates complex sparse rail configurations @@ -578,11 +577,11 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, # Set up connection points for all cities connection_points, connection_info = _generate_node_connection_points(node_positions, node_radius, - max_connection_points_per_side, - max_nr_connection_directions) + tracks_in_city) # Connect the cities through the connection points - _connect_cities(node_positions, connection_points, connection_info, city_cells, rail_trans, grid_map) + _connect_cities(node_positions, connection_points, connection_info, city_cells, max_inter_city_rails, + rail_trans, grid_map) # Build inner cities _build_inner_cities(node_positions, connection_points, rail_trans, grid_map) @@ -617,8 +616,8 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, tries = 0 while to_close: - x_tmp = node_radius + np.random.randint(height - 2 * node_radius - 1) - y_tmp = node_radius + np.random.randint(width - 2 * node_radius - 1) + x_tmp = node_radius + 1 + np.random.randint(height - 2 * node_radius - 1) + y_tmp = node_radius + 1 + np.random.randint(width - 2 * node_radius - 1) to_close = False # Check distance to nodes @@ -644,8 +643,8 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, nodes_ratio = height / width nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio))) nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row)) - x_positions = np.linspace(node_radius, height - node_radius - 1, nodes_per_row, dtype=int) - y_positions = np.linspace(node_radius, width - node_radius - 1, nodes_per_col, dtype=int) + x_positions = np.linspace(node_radius + 1, height - node_radius - 2, nodes_per_row, dtype=int) + y_positions = np.linspace(node_radius + 1, width - node_radius - 2, nodes_per_col, dtype=int) node_positions = [] city_cells = [] for node_idx in range(nb_nodes): @@ -655,13 +654,11 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, city_cells.extend(_city_cells(node_positions[-1], node_radius)) return node_positions, city_cells - def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2, - max_nr_connection_directions=2): + def _generate_node_connection_points(node_positions, node_size, tracks_in_city=2): connection_points = [] connection_info = [] - max_nr_connection_directions = np.clip(max_nr_connection_directions, 0, 4) - if max_nr_connection_points > 2 * node_size + 1: - max_nr_connection_points = 2 * node_size + 1 + if tracks_in_city > 2 * node_size + 1: + tracks_in_city = 2 * node_size + 1 for node_position in node_positions: @@ -671,21 +668,17 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, neighb_dist.append(distance_on_rail(node_position, neighb_node)) closest_neighb_idx = argsort(neighb_dist) - # Store the directions to these neighbours + # Store the directions to these neighbours and orient city to face closest neighbour connection_sides_idx = [] idx = 1 - while len(connection_sides_idx) < max_nr_connection_directions and idx < len(neighb_dist): - current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]]) - if current_closest_direction not in connection_sides_idx: - connection_sides_idx.append(current_closest_direction) - idx += 1 + current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]]) + connection_sides_idx.append(current_closest_direction) + connection_sides_idx.append((current_closest_direction + 2) % 4) - # set the number of connection points for each direction + # set the number of tracks within a city, at least 2 tracks per city connections_per_direction = np.zeros(4, dtype=int) - + nr_of_connection_points = np.random.randint(2, tracks_in_city + 1) for idx in connection_sides_idx: - nr_of_connection_points = np.random.randint(1, max_nr_connection_points + 1) - connections_per_direction[idx] = nr_of_connection_points connection_points_coordinates = [[] for i in range(4)] @@ -710,7 +703,8 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, connection_info.append(connections_per_direction) return connection_points, connection_info - def _connect_cities(node_positions, connection_points, connection_info, city_cells, rail_trans, grid_map): + def _connect_cities(node_positions, connection_points, connection_info, city_cells, max_inter_city_rails, + rail_trans, grid_map): """ Function to connect the different cities through their connection points :param node_positions: Positions of city centers @@ -723,6 +717,7 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, boarder_connections = set() for current_node in np.arange(len(node_positions)): direction = 0 + connected_to_city = [] for nbr_connection_points in connection_info[current_node]: if nbr_connection_points > 0: neighb_idx = _closest_neigh_in_direction(current_node, direction, node_positions) @@ -730,23 +725,35 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, direction += 1 continue - if neighb_idx is not None: - connection_distances = [] - for tmp_out_connection_point in connection_points[current_node][direction]: - # Find closest connection point - min_connection_dist = np.inf - all_neighb_connection_points = [item for sublist in connection_points[neighb_idx] for item in - sublist] - - for tmp_in_connection_point in all_neighb_connection_points: - tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point) - if tmp_dist < min_connection_dist: - min_connection_dist = tmp_dist - neighb_connection_point = tmp_in_connection_point - connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, - city_cells) - boarder_connections.add((tmp_out_connection_point, current_node)) - boarder_connections.add((neighb_connection_point, neighb_idx)) + if neighb_idx is None or neighb_idx in connected_to_city: + node_dist = [] + for av_node in node_positions: + node_dist.append(distance_on_rail(node_positions[current_node], av_node)) + i = 1 + neighbours = np.argsort(node_dist) + neighb_idx = neighbours[i] + while neighb_idx in connected_to_city: + i += 1 + neighb_idx = neighbours[i] + + connected_to_city.append(neighb_idx) + number_of_out_rails = np.random.randint(1, max_inter_city_rails + 1) + + for tmp_out_connection_point in connection_points[current_node][direction][:number_of_out_rails]: + # Find closest connection point + min_connection_dist = np.inf + all_neighb_connection_points = [item for sublist in connection_points[neighb_idx] for item in + sublist] + + for tmp_in_connection_point in all_neighb_connection_points: + tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point) + if tmp_dist < min_connection_dist: + min_connection_dist = tmp_dist + neighb_connection_point = tmp_in_connection_point + connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, + city_cells) + boarder_connections.add((tmp_out_connection_point, current_node)) + boarder_connections.add((neighb_connection_point, neighb_idx)) direction += 1 return boarder_connections