diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index d3a7683ea7af3b93d01eb12ed405e2084c8526c7..7ea9371527952eb28fdd2656f641cc593d62f36d 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -768,13 +768,15 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, for current_city in range(len(node_positions)): all_outer_connection_points = [item for sublist in outer_connection_points[current_city] for item in sublist] + city_boarder = _city_boarder(node_positions[current_city], node_radius) - for boarder in range(4): + random_boarders = np.random.choice(np.arange(4), 4, False) + # TODO: Only look at the relevant boarders (Only two at the moment) + for boarder in random_boarders: for source in inner_connection_points[current_city][boarder]: - for other_boarder in range(4): + for other_boarder in random_boarders: if boarder != other_boarder and len(inner_connection_points[current_city][other_boarder]) > 0: for target in inner_connection_points[current_city][other_boarder]: - city_boarder = _city_boarder(node_positions[current_city], node_radius) current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) if target in all_outer_connection_points and source in \ all_outer_connection_points and len(through_path_cells[current_city]) < 1: