diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index e3b59e4fccba235b008eb8c576f5bdb334061248..e8b4f1fefac1dedada1949575105106995fbf172 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -998,7 +998,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) - for node_idx in range(num_cities + num_intersections): to_close = True tries = 0 @@ -1097,7 +1096,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # We currently place them uniformly distirbuted among all cities if num_cities > 1: train_stations = [[] for i in range(num_cities)] - + built_num_trainstation = 0 for station in range(num_trainstations): trainstation_node = int(station / num_trainstations * num_cities) @@ -1130,9 +1129,11 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # Check if connection was made if len(connection) == 0: train_stations[trainstation_node].pop(-1) + else: + built_num_trainstation += 1 # Adjust the number of agents if you could not build enough trainstations - built_num_trainstation = len(train_stations) + if num_agents > built_num_trainstation: num_agents = built_num_trainstation warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") @@ -1193,7 +1194,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 target_node = np.random.choice(avail_target_nodes) tries += 1 # Test again with new start node if no pair is found (This code needs to be improved) - if tries > 10: + if (tries + 1) % 10 == 0: start_node = np.random.choice(avail_start_nodes) if tries > 100: warnings.warn("Could not set trainstations, please change initial parameters!!!!")