From fe03773e6e0c4c3f76b6d3a75d68ffa4a7e65230 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Thu, 12 Sep 2019 19:09:45 +0200 Subject: [PATCH] refactored --- .../Simple_Realistic_Railway_Generator.py | 47 ++++++++++--------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index 3d70935c..de3831a3 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -345,10 +345,31 @@ def realistic_rail_generator(num_cities=5, print("************* NBR of graphs:", len(np.unique(graph_ids))) return graph, np.unique(graph_ids).astype(int) + def connect_sub_graphs(rail_trans, rail_array, org_s_nodes, org_e_nodes, city_edges, nodes_added): + _, graphids = calc_nbr_of_graphs(city_edges) + if len(graphids) > 0: + for i in range(len(graphids) - 1): + connection = [] + cnt = 0 + while len(connection) == 0 and cnt < 100: + s_nodes = copy.deepcopy(org_s_nodes) + e_nodes = copy.deepcopy(org_e_nodes) + start_nodes = s_nodes[graphids[i]] + end_nodes = e_nodes[graphids[i + 1]] + start_node = start_nodes[np.random.choice(len(start_nodes))] + end_node = end_nodes[np.random.choice(len(end_nodes))] + rail_array[start_node] = 0 + rail_array[end_node] = 0 + connection = connect_rail(rail_trans, rail_array, start_node, end_node) + if len(connection) > 0: + nodes_added.append(start_node) + nodes_added.append(end_node) + cnt += 1 + def connect_stations(rail_trans, rail_array, org_s_nodes, org_e_nodes, nodes_added, inter_connect_max_nbr_of_shortes_city): - graph = [] + city_edges = [] s_nodes = copy.deepcopy(org_s_nodes) e_nodes = copy.deepcopy(org_e_nodes) @@ -382,33 +403,15 @@ def realistic_rail_generator(num_cities=5, a = (city_loop, cl, np.inf) if city_loop > cl: a = (cl, city_loop, np.inf) - if not (a in graph): - graph.append(a) + if not (a in city_edges): + city_edges.append(a) nodes_added.append(start_node) nodes_added.append(end_node) else: rail_array[start_node] = tmp_trans_sn rail_array[end_node] = tmp_trans_en - _, graphids = calc_nbr_of_graphs(graph) - if len(graphids) > 0: - for i in range(len(graphids) - 1): - connection = [] - cnt = 0 - while len(connection) == 0 and cnt < 100: - s_nodes = copy.deepcopy(org_s_nodes) - e_nodes = copy.deepcopy(org_e_nodes) - start_nodes = s_nodes[graphids[i]] - end_nodes = e_nodes[graphids[i + 1]] - start_node = start_nodes[np.random.choice(len(start_nodes))] - end_node = end_nodes[np.random.choice(len(end_nodes))] - rail_array[start_node] = 0 - rail_array[end_node] = 0 - connection = connect_rail(rail_trans, rail_array, start_node, end_node) - if len(connection) > 0: - nodes_added.append(start_node) - nodes_added.append(end_node) - cnt += 1 + connect_sub_graphs(rail_trans, rail_array, org_s_nodes, org_e_nodes, city_edges, nodes_added) def connect_random_stations(rail_trans, rail_array, start_nodes_added, end_nodes_added, nodes_added, inter_connect_max_nbr_of_shortes_city): -- GitLab