From fe03773e6e0c4c3f76b6d3a75d68ffa4a7e65230 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Thu, 12 Sep 2019 19:09:45 +0200
Subject: [PATCH] refactored

---
 .../Simple_Realistic_Railway_Generator.py     | 47 ++++++++++---------
 1 file changed, 25 insertions(+), 22 deletions(-)

diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py
index 3d70935c..de3831a3 100644
--- a/examples/Simple_Realistic_Railway_Generator.py
+++ b/examples/Simple_Realistic_Railway_Generator.py
@@ -345,10 +345,31 @@ def realistic_rail_generator(num_cities=5,
             print("************* NBR of graphs:", len(np.unique(graph_ids)))
         return graph, np.unique(graph_ids).astype(int)
 
+    def connect_sub_graphs(rail_trans, rail_array, org_s_nodes, org_e_nodes, city_edges, nodes_added):
+        _, graphids = calc_nbr_of_graphs(city_edges)
+        if len(graphids) > 0:
+            for i in range(len(graphids) - 1):
+                connection = []
+                cnt = 0
+                while len(connection) == 0 and cnt < 100:
+                    s_nodes = copy.deepcopy(org_s_nodes)
+                    e_nodes = copy.deepcopy(org_e_nodes)
+                    start_nodes = s_nodes[graphids[i]]
+                    end_nodes = e_nodes[graphids[i + 1]]
+                    start_node = start_nodes[np.random.choice(len(start_nodes))]
+                    end_node = end_nodes[np.random.choice(len(end_nodes))]
+                    rail_array[start_node] = 0
+                    rail_array[end_node] = 0
+                    connection = connect_rail(rail_trans, rail_array, start_node, end_node)
+                    if len(connection) > 0:
+                        nodes_added.append(start_node)
+                        nodes_added.append(end_node)
+                    cnt += 1
+
     def connect_stations(rail_trans, rail_array, org_s_nodes, org_e_nodes, nodes_added,
                          inter_connect_max_nbr_of_shortes_city):
 
-        graph = []
+        city_edges = []
 
         s_nodes = copy.deepcopy(org_s_nodes)
         e_nodes = copy.deepcopy(org_e_nodes)
@@ -382,33 +403,15 @@ def realistic_rail_generator(num_cities=5,
                             a = (city_loop, cl, np.inf)
                             if city_loop > cl:
                                 a = (cl, city_loop, np.inf)
-                            if not (a in graph):
-                                graph.append(a)
+                            if not (a in city_edges):
+                                city_edges.append(a)
                             nodes_added.append(start_node)
                             nodes_added.append(end_node)
                         else:
                             rail_array[start_node] = tmp_trans_sn
                             rail_array[end_node] = tmp_trans_en
 
-        _, graphids = calc_nbr_of_graphs(graph)
-        if len(graphids) > 0:
-            for i in range(len(graphids) - 1):
-                connection = []
-                cnt = 0
-                while len(connection) == 0 and cnt < 100:
-                    s_nodes = copy.deepcopy(org_s_nodes)
-                    e_nodes = copy.deepcopy(org_e_nodes)
-                    start_nodes = s_nodes[graphids[i]]
-                    end_nodes = e_nodes[graphids[i + 1]]
-                    start_node = start_nodes[np.random.choice(len(start_nodes))]
-                    end_node = end_nodes[np.random.choice(len(end_nodes))]
-                    rail_array[start_node] = 0
-                    rail_array[end_node] = 0
-                    connection = connect_rail(rail_trans, rail_array, start_node, end_node)
-                    if len(connection) > 0:
-                        nodes_added.append(start_node)
-                        nodes_added.append(end_node)
-                    cnt += 1
+        connect_sub_graphs(rail_trans, rail_array, org_s_nodes, org_e_nodes, city_edges, nodes_added)
 
     def connect_random_stations(rail_trans, rail_array, start_nodes_added, end_nodes_added, nodes_added,
                                 inter_connect_max_nbr_of_shortes_city):
-- 
GitLab