diff --git a/flatland/envs/rail_generators_city_generator.py b/flatland/envs/rail_generators_city_generator.py index 079809dadb69713d5b3d7950d186b6f8da322aa6..b4d854bbde0ef3d5b1ed810ef809b8514187a8ab 100644 --- a/flatland/envs/rail_generators_city_generator.py +++ b/flatland/envs/rail_generators_city_generator.py @@ -84,7 +84,7 @@ def city_generator(num_cities: int = 5, # noinspection PyTypeChecker def create_stations_from_city_locations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - generate_city_locations: IntVector2DArray, + generate_city_locations: IntVector2DArrayArray, intern_max_number_of_station_tracks: int) -> (IntVector2DArray, IntVector2DArray, IntVector2DArray, @@ -144,7 +144,7 @@ def city_generator(num_cities: int = 5, # noinspection PyTypeChecker def create_switches_at_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - station_tracks: IntVector2DArray, + station_tracks: IntVector2DArrayArray, nodes_added: IntVector2DArray, intern_nbr_of_switches_per_station_track: int) -> IntVector2DArray: @@ -429,44 +429,40 @@ def city_generator(num_cities: int = 5, # ---------------------------------------------------------------------------------- # generate city topology - nodes_added, train_stations, s_nodes, e_nodes, station_tracks = \ + nodes_added, train_stations_slots, s_nodes, e_nodes, station_tracks = \ create_stations_from_city_locations(rail_trans, grid_map, generate_city_locations, intern_max_number_of_station_tracks) # build switches - # TODO remove true/false block - if True: - create_switches_at_stations(rail_trans, grid_map, station_tracks, nodes_added, - intern_nbr_of_switches_per_station_track) + create_switches_at_stations(rail_trans, grid_map, station_tracks, nodes_added, + intern_nbr_of_switches_per_station_track) # ---------------------------------------------------------------------------------- # connect stations - # TODO remove true/false block - if True: - if do_random_connect_stations: - connect_random_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added, - intern_connect_max_nbr_of_shortes_city) - else: - connect_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added, - intern_connect_max_nbr_of_shortes_city) + if do_random_connect_stations: + connect_random_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added, + intern_connect_max_nbr_of_shortes_city) + else: + connect_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added, + intern_connect_max_nbr_of_shortes_city) # ---------------------------------------------------------------------------------- # fix all transition at starting / ending points (mostly add a dead end, if missing) - # TODO i would like to remove the fixing stuff. + # TODO we might have to remove the fixing stuff in the future for i in range(len(nodes_added)): grid_map.fix_transitions(nodes_added[i]) # ---------------------------------------------------------------------------------- - # remove stations where rail is a switch - remove_switch_stations(rail_trans, grid_map, train_stations) + # remove stations where underlaying rail is a switch + remove_switch_stations(rail_trans, grid_map, train_stations_slots) # ---------------------------------------------------------------------------------- # Slot availability in node node_available_start = [] node_available_target = [] for node_idx in range(max_num_cities): - node_available_start.append(len(train_stations[node_idx])) - node_available_target.append(len(train_stations[node_idx])) + node_available_start.append(len(train_stations_slots[node_idx])) + node_available_target.append(len(train_stations_slots[node_idx])) # Assign agents to slots for agent_idx in range(num_agents): @@ -499,7 +495,7 @@ def city_generator(num_cities: int = 5, return grid_map, {'agents_hints': { 'num_agents': num_agents, 'agent_start_targets_nodes': agent_start_targets_nodes, - 'train_stations': train_stations + 'train_stations': train_stations_slots }} return generator