diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index b0c4cc64e4cd27f9cf165058ccda1e9131fd0bbe..a86e95d2c486486341294dece158c7789218aea8 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -815,7 +815,13 @@ def sparse_rail_generator(num_cities=5, node_radius=2, return train_stations, built_num_trainstations def _generate_start_target_pairs(num_agents, nb_nodes, train_stations): - + """ + Fill the trainstation positions with targets and goals + :param num_agents: + :param nb_nodes: + :param train_stations: + :return: + """ # Generate start and target node directory for all agents. # Assure that start and target are not in the same node agent_start_targets_nodes = [] @@ -831,8 +837,13 @@ def sparse_rail_generator(num_cities=5, node_radius=2, for agent_idx in range(num_agents): avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0] avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0] - start_node = np.random.choice(avail_start_nodes) - target_node = np.random.choice(avail_target_nodes) + # Set probability to choose start and stop from trainstations + sum_start = sum(node_available_start) + sum_target = sum(node_available_target) + p_avail_start = [float(i) / sum_start for i in node_available_start] + p_avail_target = [float(i) / sum_target for i in node_available_target] + start_node = np.random.choice(avail_start_nodes, p=p_avail_start) + target_node = np.random.choice(avail_target_nodes, p=p_avail_target) tries = 0 found_agent_pair = True while target_node == start_node: