diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 89806bbb45af6614b50f2160e061a49434f45687..e1128fc7d0b07c1759920d430ec3547750193d73 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -776,6 +776,9 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, """ through_path_cells = [[] for i in range(len(node_positions))] for current_city in range(len(node_positions)): + all_outer_connection_points = [item for sublist in outer_connection_points[current_city] for item in + sublist] + for boarder in range(4): for source in connection_points[current_city][boarder]: for other_boarder in range(4): @@ -783,8 +786,8 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, for target in connection_points[current_city][other_boarder]: city_boarder = _city_boarder(node_positions[current_city], node_radius) current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) - if target in outer_connection_points[current_city] and source in \ - outer_connection_points[current_city] and len(through_path_cells[current_city]) < 1: + if target in all_outer_connection_points and source in \ + all_outer_connection_points and len(through_path_cells[current_city]) < 1: through_path_cells[current_city].extend(current_track) else: continue