diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 12f844225d321e3c43026bebbe276af56a8caaba..69959ca8c056465260afc987b1100d0cfb316285 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -34,13 +34,11 @@ env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are) num_trainstations=50, # Number of possible start/targets on map - min_node_dist=30, # Minimal distance of nodes + min_node_dist=15, # Minimal distance of nodes node_radius=4, # Proximity of stations to city center - num_neighb=3, # Number of connections to other cities/intersections seed=15, # Random seed grid_mode=True, - nr_parallel_tracks=1, - connection_points_per_side=2, + connection_points_per_side=3, max_nr_connection_directions=4, ), schedule_generator=sparse_schedule_generator(), diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index ee218dc6827644eca4b05ce9f1d8b09bc49f9696..178da0c71cbe06656e2002f7ff7dfba36eba4c58 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -533,7 +533,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, node_radius=2, - num_neighb=3, nr_parallel_tracks=2, grid_mode=False, connection_points_per_side=4, + grid_mode=False, connection_points_per_side=4, max_nr_connection_directions=2, seed=0) -> RailGenerator: """ @@ -563,9 +563,6 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n rail_array = grid_map.grid rail_array.fill(0) np.random.seed(seed + num_resets) - tracks_between_cities = nr_parallel_tracks - if nr_parallel_tracks >= connection_points_per_side: - tracks_between_cities = connection_points_per_side # Generate a set of nodes for the sparse network # Try to connect cities to nodes first @@ -605,7 +602,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n max_nr_connection_directions) # Connect the cities through the connection points - _connect_cities(node_positions, connection_points, connection_info, tracks_between_cities, rail_trans, grid_map) + _connect_cities(node_positions, connection_points, connection_info, rail_trans, grid_map) # Build inner cities train_stations, built_num_trainstation = _build_cities(node_positions, connection_points, rail_trans, grid_map) @@ -699,6 +696,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n # Store the directions to these neighbours connection_sides_idx = [] + for idx in range(1, min(len(neighb_dist) - 1, max_nr_connection_directions) + 1): connection_sides_idx.append(closest_direction(node_position, node_positions[closest_neighb_idx[idx]])) @@ -706,7 +704,9 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n connections_per_direction = np.zeros(4, dtype=int) for idx in connection_sides_idx: - connections_per_direction[idx] = max_nr_connection_points + nr_of_connection_points = np.random.randint(1, max_nr_connection_points + 1) + + connections_per_direction[idx] = nr_of_connection_points connection_points_coordinates = [] for direction in range(4): @@ -730,7 +730,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n connection_info.append(connections_per_direction) return connection_points, connection_info - def _connect_cities(node_positions, connection_points, connection_info, tracks_between_cities, rail_trans, + def _connect_cities(node_positions, connection_points, connection_info, rail_trans, grid_map): """ Function to connect the different cities through their connection points @@ -759,7 +759,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n tmp_dist_to_node = distance_on_rail(tmp_out_connection_point, node_positions[neighb_idx]) connection_distances.append(tmp_dist_to_node) possible_connection_points = argsort(connection_distances) - for sort_idx in possible_connection_points[:tracks_between_cities]: + for sort_idx in possible_connection_points[:connection_info[current_node][direction]]: # Find closest connection point tmp_out_connection_point = connection_points[current_node][sort_idx] min_connection_dist = np.inf