diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index d7050d63d0938339a9953e6351173806ffd356af..ad5c78a10f47a03e40880301cc6fdfdda5135feb 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -604,13 +604,10 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, print("City build time", time.time() - city_build_time) # Populate cities train_station_time = time.time() - train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, free_tracks, grid_map) + train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, node_radius, free_tracks, + grid_map) print("Trainstation placing time", time.time() - train_station_time) - # Adjust the number of agents if you could not build enough trainstations - if num_agents > built_num_trainstation: - num_agents = built_num_trainstation - warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") # Fix all transition elements grid_fix_time = time.time() @@ -822,7 +819,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, free_tracks[current_city].append(current_track) return through_path_cells, free_tracks - def _set_trainstation_positions(node_positions, free_tracks, grid_map): + def _set_trainstation_positions(node_positions, node_radius, free_tracks, grid_map): """ :param node_positions: @@ -836,20 +833,8 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, built_num_trainstations = 0 for current_city in range(len(node_positions)): for track_nbr in range(len(free_tracks[current_city])): - for possible_location in free_tracks[current_city][track_nbr]: - # Only build trainstation on non diverging elements - cell_type = grid_map.get_full_transitions(*possible_location) - nbits = 0 - while cell_type > 0: - nbits += (cell_type & 1) - cell_type = cell_type >> 1 - if 1 <= nbits <= 2: - built_num_trainstations += 1 - if track_nbr % 2 == 0: - left += 1 - else: - right += 1 - train_stations[current_city].append((possible_location, track_nbr)) + possible_location = free_tracks[current_city][track_nbr][node_radius] + train_stations[current_city].append((possible_location, track_nbr)) return train_stations, built_num_trainstations def _generate_start_target_pairs(num_agents, nb_nodes, train_stations): @@ -881,29 +866,10 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, p_avail_start = [float(i) / sum_start for i in np.array(node_available_start)[avail_start_nodes]] p_avail_target = [float(i) / sum_target for i in np.array(node_available_target)[avail_target_nodes]] - start_node = np.random.choice(avail_start_nodes, p=p_avail_start) - target_node = np.random.choice(avail_target_nodes, p=p_avail_target) - tries = 0 - found_agent_pair = True - while target_node == start_node: - target_node = np.random.choice(avail_target_nodes) - tries += 1 - # Test again with new start node if no pair is found (This code needs to be improved) - if (tries + 1) % 10 == 0: - start_node = np.random.choice(avail_start_nodes) - if tries > 100: - warnings.warn("Could not set trainstations, removing agent!") - found_agent_pair = False - break - if found_agent_pair: - node_available_start[start_node] -= 1 - node_available_target[target_node] -= 1 - shortest_path = nx.astar_path(G, start_node, target_node, weight='length') - start_orientation = nx.get_edge_attributes(G, "direction")[(shortest_path[0], shortest_path[1])] - agent_start_targets_nodes.append((start_node, target_node, start_orientation)) - - else: - num_agents -= 1 + start_target_tuple = np.random.choice(avail_start_nodes, p=p_avail_start, size=2, replace=False) + start_node = start_target_tuple[0] + target_node = start_target_tuple[1] + agent_start_targets_nodes.append((start_node, target_node, 0)) return agent_start_targets_nodes, num_agents def _fix_transitions(city_cells, inter_city_lines, grid_map): diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index d6f5a2c3d50bbf4b2d95113144e10f7eef15291e..30f6b57e2f5376ed188441185e4afc1ee8602ffb 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -70,47 +70,20 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> agents_position = [] agents_target = [] agents_direction = [] - start_slots = train_stations - target_slots = train_stations for agent_idx in range(num_agents): # Set target for agent - current_target_node = agent_start_targets_nodes[agent_idx][1] - target_station_idx = np.random.randint(len(target_slots[current_target_node])) - target = target_slots[current_target_node][target_station_idx] - target_slots[current_target_node].pop(target_station_idx) - agents_target.append((target[0][0], target[0][1])) - - # Set start for agent and corresponding orientation - current_start_node = agent_start_targets_nodes[agent_idx][0] - agent_start_orientation = agent_start_targets_nodes[agent_idx][2] - - # Place the agent on the corresponding track - if city_orientations[current_start_node] == agent_start_orientation: - track_to_use = 0 - else: - track_to_use = 1 - - for i in range(len(start_slots[current_start_node])): - if start_slots[current_start_node][i][1] == track_to_use: - start_station_idx = i - break - else: - start_station_idx = None - - if start_station_idx is None: - warnings.warn("No slot available with required start orientation") - continue - start = start_slots[current_start_node][start_station_idx] - start_slots[current_start_node].pop(start_station_idx) - + start_city = agent_start_targets_nodes[agent_idx][0] + target_city = agent_start_targets_nodes[agent_idx][0] + agent_orientation = agent_start_targets_nodes[agent_idx][2] + start_city_idx = np.random.randint(len(train_stations[start_city])) + start = train_stations[start_city][start_city_idx] + target_station_idx = np.random.randint(len(train_stations[target_city])) + target = train_stations[target_city][target_station_idx] agents_position.append((start[0][0], start[0][1])) - agents_direction.append(agent_start_orientation) + agents_target.append((target[0][0], target[0][1])) + agents_direction.append(agent_orientation) # Orient the agent correctly - for agent_idx in range(1, len(agents_position)): - agents_position[agent_idx] = agents_position[0] - agents_direction[agent_idx] = (agents_direction[0] + np.random.choice([0, 2])) % 4 - if speed_ratio_map: speeds = speed_initialization_helper(num_agents, speed_ratio_map) else: