From 1c5a9215893209fcd48da6cb144b2ecc403e971f Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Fri, 23 Aug 2019 17:22:56 -0400 Subject: [PATCH] added stability measures for trainstation setting and filling --- flatland/envs/generators.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index e3b59e4f..e8b4f1fe 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -998,7 +998,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) - for node_idx in range(num_cities + num_intersections): to_close = True tries = 0 @@ -1097,7 +1096,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # We currently place them uniformly distirbuted among all cities if num_cities > 1: train_stations = [[] for i in range(num_cities)] - + built_num_trainstation = 0 for station in range(num_trainstations): trainstation_node = int(station / num_trainstations * num_cities) @@ -1130,9 +1129,11 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # Check if connection was made if len(connection) == 0: train_stations[trainstation_node].pop(-1) + else: + built_num_trainstation += 1 # Adjust the number of agents if you could not build enough trainstations - built_num_trainstation = len(train_stations) + if num_agents > built_num_trainstation: num_agents = built_num_trainstation warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") @@ -1193,7 +1194,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 target_node = np.random.choice(avail_target_nodes) tries += 1 # Test again with new start node if no pair is found (This code needs to be improved) - if tries > 10: + if (tries + 1) % 10 == 0: start_node = np.random.choice(avail_start_nodes) if tries > 100: warnings.warn("Could not set trainstations, please change initial parameters!!!!") -- GitLab