diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index e381733faeb1c83ef72a3e19f10f17dce4d02966..b1423250cf492a18ed625b5c8a9a6739d7eea2a1 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -580,14 +580,16 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, tracks_in_city) # Connect the cities through the connection points - _connect_cities(node_positions, connection_points, connection_info, city_cells, max_inter_city_rails, - rail_trans, grid_map) + outer_connection_points = _connect_cities(node_positions, connection_points, connection_info, city_cells, + max_inter_city_rails, + rail_trans, grid_map) # Build inner cities - _build_inner_cities(node_positions, connection_points, rail_trans, grid_map) + through_tracks = _build_inner_cities(node_positions, connection_points, outer_connection_points, rail_trans, + grid_map) # Populate cities - train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, grid_map) + train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, through_tracks, grid_map) # Adjust the number of agents if you could not build enough trainstations if num_agents > built_num_trainstation: @@ -714,7 +716,7 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, :param grid_map: Grid map :return: """ - boarder_connections = set() + boarder_connections = [[] for i in range(len(node_positions))] for current_node in np.arange(len(node_positions)): direction = 0 connected_to_city = [] @@ -752,12 +754,14 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, neighb_connection_point = tmp_in_connection_point connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, city_cells) - boarder_connections.add((tmp_out_connection_point, current_node)) - boarder_connections.add((neighb_connection_point, neighb_idx)) + if tmp_out_connection_point not in boarder_connections[current_node]: + boarder_connections[current_node].append(tmp_out_connection_point) + if neighb_connection_point not in boarder_connections[neighb_idx]: + boarder_connections[neighb_idx].append(neighb_connection_point) direction += 1 return boarder_connections - def _build_inner_cities(node_positions, connection_points, rail_trans, grid_map): + def _build_inner_cities(node_positions, connection_points, outer_connection_points, rail_trans, grid_map): """ Builds inner city tracks. This current version connects all incoming connections to all outgoing connections :param node_positions: @@ -766,6 +770,7 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, :param grid_map: :return: """ + through_path_cells = [[] for i in range(len(node_positions))] for current_city in range(len(node_positions)): for boarder in range(4): for source in connection_points[current_city][boarder]: @@ -773,13 +778,16 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, if boarder != other_boarder and len(connection_points[current_city][other_boarder]) > 0: for target in connection_points[current_city][other_boarder]: city_boarder = _city_boarder(node_positions[current_city], node_radius) - connect_cities(rail_trans, grid_map, source, target, city_boarder) + current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) + if target in outer_connection_points[current_city] and source in \ + outer_connection_points[current_city]: + through_path_cells[current_city].extend(current_track) else: continue - return + return through_path_cells - def _set_trainstation_positions(node_positions, grid_map): + def _set_trainstation_positions(node_positions, through_tracks, grid_map): """ :param node_positions: @@ -791,6 +799,8 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, built_num_trainstations = 0 for current_city in range(len(node_positions)): for possible_location in _city_cells(node_positions[current_city], node_radius - 1): + if possible_location in through_tracks[current_city]: + continue cell_type = grid_map.get_full_transitions(*possible_location) nbits = 0 while cell_type > 0: