From a85b1df0c609f425a19434708ee68852b3ab9662 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Fri, 23 Aug 2019 17:34:50 -0400 Subject: [PATCH] added stability measures for trainstation setting and filling --- flatland/envs/generators.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index e8b4f1fe..f79c5600 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -1190,6 +1190,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 start_node = np.random.choice(avail_start_nodes) target_node = np.random.choice(avail_target_nodes) tries = 0 + found_agent_pair = True while target_node == start_node: target_node = np.random.choice(avail_target_nodes) tries += 1 @@ -1197,13 +1198,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 if (tries + 1) % 10 == 0: start_node = np.random.choice(avail_start_nodes) if tries > 100: - warnings.warn("Could not set trainstations, please change initial parameters!!!!") + warnings.warn("Could not set trainstations, removing agent!") + found_agent_pair = False break - - node_available_start[start_node] -= 1 - node_available_target[target_node] -= 1 - - agent_start_targets_nodes.append((start_node, target_node)) + if found_agent_pair: + node_available_start[start_node] -= 1 + node_available_target[target_node] -= 1 + agent_start_targets_nodes.append((start_node, target_node)) + else: + num_agents -= 1 # Place agents and targets within available train stations agents_position = [] @@ -1211,7 +1214,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 agents_direction = [] for agent_idx in range(num_agents): - # Set target for agent current_target_node = agent_start_targets_nodes[agent_idx][1] target_station_idx = np.random.randint(len(train_stations[current_target_node])) @@ -1222,7 +1224,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 target = train_stations[current_target_node][target_station_idx] tries += 1 if tries > 100: - warnings.warn("Could not set target position, please change initial parameters!!!!") + warnings.warn("Could not set target position, removing an agent") break agents_target.append((target[0], target[1])) -- GitLab