diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index c2f4f43809ffe3294e350dc16d86df13a417b984..a42146c4a2411dbd00707fa1d850ec0bde46a487 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -598,16 +598,18 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, print("City connection time", time.time() - city_connection_time) # Build inner cities city_build_time = time.time() - through_tracks = _build_inner_cities(node_positions, inner_connection_points, outer_connection_points, - node_radius, - rail_trans, - grid_map) + through_tracks, free_tracks = _build_inner_cities(node_positions, inner_connection_points, + outer_connection_points, + node_radius, + rail_trans, + grid_map) print("City build time", time.time() - city_build_time) # Populate cities train_station_time = time.time() train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, city_orientations, through_tracks, + free_tracks, node_radius, grid_map) print("Trainstation placing time", time.time() - train_station_time) @@ -799,6 +801,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, :return: Returns the cells of the through path which cannot be occupied by trainstations """ through_path_cells = [[] for i in range(len(node_positions))] + free_tracks = [[] for i in range(len(node_positions))] for current_city in range(len(node_positions)): all_outer_connection_points = [item for sublist in outer_connection_points[current_city] for item in sublist] @@ -821,25 +824,18 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, # Connect parallel tracks for track_id in range(len(inner_connection_points[current_city][boarder])): - if track_id % 2 == 0: - source = inner_connection_points[current_city][boarder][track_id] - target = inner_connection_points[current_city][opposite_boarder][track_id] - current_track = connect_straigt_line(rail_trans, grid_map, source, target, False) - if target in all_outer_connection_points and source in \ - all_outer_connection_points and len(through_path_cells[current_city]) < 1: - through_path_cells[current_city].extend(current_track) + source = inner_connection_points[current_city][boarder][track_id] + target = inner_connection_points[current_city][opposite_boarder][track_id] + current_track = connect_straigt_line(rail_trans, grid_map, source, target, False) + if target in all_outer_connection_points and source in \ + all_outer_connection_points and len(through_path_cells[current_city]) < 1: + through_path_cells[current_city].extend(current_track) else: - source = inner_connection_points[current_city][opposite_boarder][track_id] - target = inner_connection_points[current_city][boarder][track_id] + free_tracks[current_city].append(current_track) + return through_path_cells, free_tracks - current_track = connect_straigt_line(rail_trans, grid_map, source, target, False) - if target in all_outer_connection_points and source in \ - all_outer_connection_points and len(through_path_cells[current_city]) < 1: - through_path_cells[current_city].extend(current_track) - - return through_path_cells - - def _set_trainstation_positions(node_positions, city_orientations, through_tracks, node_radius, grid_map): + def _set_trainstation_positions(node_positions, city_orientations, through_tracks, free_tracks, node_radius, + grid_map): """ :param node_positions: @@ -848,23 +844,26 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, """ nb_nodes = len(node_positions) train_stations = [[] for i in range(nb_nodes)] - + left = 0 + right = 0 built_num_trainstations = 0 for current_city in range(len(node_positions)): - for possible_location in _city_cells(node_positions[current_city], node_radius - 1): - if possible_location in through_tracks[current_city]: - continue - cell_type = grid_map.get_full_transitions(*possible_location) - nbits = 0 - while cell_type > 0: - nbits += (cell_type & 1) - cell_type = cell_type >> 1 - if 1 <= nbits <= 2: - built_num_trainstations += 1 - track_nbr = _track_number(node_positions[current_city], city_orientations[current_city], - possible_location) - train_stations[current_city].append((possible_location, track_nbr)) - + for track_nbr in range(len(free_tracks[current_city])): + for possible_location in free_tracks[current_city][track_nbr]: + # Only build trainstation on non diverging elements + cell_type = grid_map.get_full_transitions(*possible_location) + nbits = 0 + while cell_type > 0: + nbits += (cell_type & 1) + cell_type = cell_type >> 1 + if 1 <= nbits <= 2: + built_num_trainstations += 1 + if track_nbr % 2 == 0: + left += 1 + else: + right += 1 + train_stations[current_city].append((possible_location, track_nbr)) + print(left, right) return train_stations, built_num_trainstations def _generate_start_target_pairs(num_agents, nb_nodes, train_stations): @@ -1010,7 +1009,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, else: return (np.abs(city_position[1] - position[1]) + 1) % 2 else: - if city_position[0] - position[0] < 0: + if city_position[0] - position[0] > 0: return np.abs(city_position[0] - position[0]) % 2 else: return (np.abs(city_position[0] - position[0]) + 1) % 2