From 4418d3a00a2983f742bbd6ed37fb093d3b9d2214 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Fri, 13 Sep 2019 08:28:50 +0200
Subject: [PATCH] refactoring

---
 .../Simple_Realistic_Railway_Generator.py     | 116 +++++++++---------
 1 file changed, 58 insertions(+), 58 deletions(-)

diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py
index 6f9061b0..41506ba9 100644
--- a/examples/Simple_Realistic_Railway_Generator.py
+++ b/examples/Simple_Realistic_Railway_Generator.py
@@ -142,56 +142,56 @@ class Vec2dOperations:
         return (x1, y1)
 
 
-def min_max_cut(min_v, max_v, v):
-    return max(min_v, min(max_v, v))
-
-
-def add_rail(width, height, grid_map, pt_from, pt_via, pt_to, bAddRemove=True):
-    gRCTrans = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])  # NESW in RC
-
-    lrcStroke = [[min_max_cut(0, height - 1, pt_from[0]),
-                  min_max_cut(0, width - 1, pt_from[1])],
-                 [min_max_cut(0, height - 1, pt_via[0]),
-                  min_max_cut(0, width - 1, pt_via[1])],
-                 [min_max_cut(0, height - 1, pt_to[0]),
-                  min_max_cut(0, width - 1, pt_to[1])]]
-
-    rc3Cells = np.array(lrcStroke[:3])  # the 3 cells
-    rcMiddle = rc3Cells[1]  # the middle cell which we will update
-    bDeadend = np.all(lrcStroke[0] == lrcStroke[2])  # deadend means cell 0 == cell 2
-
-    # get the 2 row, col deltas between the 3 cells, eg [[-1,0],[0,1]] = North, East
-    rc2Trans = np.diff(rc3Cells, axis=0)
-
-    # get the direction index for the 2 transitions
-    liTrans = []
-    for rcTrans in rc2Trans:
-        # gRCTrans - rcTrans gives an array of vector differences between our rcTrans
-        # and the 4 directions stored in gRCTrans.
-        # Where the vector difference is zero, we have a match...
-        # np.all detects where the whole row,col vector is zero.
-        # argwhere gives the index of the zero vector, ie the direction index
-        iTrans = np.argwhere(np.all(gRCTrans - rcTrans == 0, axis=1))
-        if len(iTrans) > 0:
-            iTrans = iTrans[0][0]
-            liTrans.append(iTrans)
-
-    # check that we have two transitions
-    if len(liTrans) == 2:
-        # Set the transition
-        # Set the transition
-        # If this transition spans 3 cells, it is not a deadend, so remove any deadends.
-        # The user will need to resolve any conflicts.
-        grid_map.set_transition((*rcMiddle, liTrans[0]),
-                                liTrans[1],
-                                bAddRemove,
-                                remove_deadends=not bDeadend)
-
-        # Also set the reverse transition
-        # use the reversed outbound transition for inbound
-        # and the reversed inbound transition for outbound
-        grid_map.set_transition((*rcMiddle, mirror(liTrans[1])),
-                                mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend)
+class Grid_Map_Op:
+    def min_max_cut(min_v, max_v, v):
+        return max(min_v, min(max_v, v))
+
+    def add_rail(width, height, grid_map, pt_from, pt_via, pt_to, bAddRemove=True):
+        gRCTrans = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])  # NESW in RC
+
+        lrcStroke = [[Grid_Map_Op.min_max_cut(0, height - 1, pt_from[0]),
+                      Grid_Map_Op.min_max_cut(0, width - 1, pt_from[1])],
+                     [Grid_Map_Op.min_max_cut(0, height - 1, pt_via[0]),
+                      Grid_Map_Op.min_max_cut(0, width - 1, pt_via[1])],
+                     [Grid_Map_Op.min_max_cut(0, height - 1, pt_to[0]),
+                      Grid_Map_Op.min_max_cut(0, width - 1, pt_to[1])]]
+
+        rc3Cells = np.array(lrcStroke[:3])  # the 3 cells
+        rcMiddle = rc3Cells[1]  # the middle cell which we will update
+        bDeadend = np.all(lrcStroke[0] == lrcStroke[2])  # deadend means cell 0 == cell 2
+
+        # get the 2 row, col deltas between the 3 cells, eg [[-1,0],[0,1]] = North, East
+        rc2Trans = np.diff(rc3Cells, axis=0)
+
+        # get the direction index for the 2 transitions
+        liTrans = []
+        for rcTrans in rc2Trans:
+            # gRCTrans - rcTrans gives an array of vector differences between our rcTrans
+            # and the 4 directions stored in gRCTrans.
+            # Where the vector difference is zero, we have a match...
+            # np.all detects where the whole row,col vector is zero.
+            # argwhere gives the index of the zero vector, ie the direction index
+            iTrans = np.argwhere(np.all(gRCTrans - rcTrans == 0, axis=1))
+            if len(iTrans) > 0:
+                iTrans = iTrans[0][0]
+                liTrans.append(iTrans)
+
+        # check that we have two transitions
+        if len(liTrans) == 2:
+            # Set the transition
+            # Set the transition
+            # If this transition spans 3 cells, it is not a deadend, so remove any deadends.
+            # The user will need to resolve any conflicts.
+            grid_map.set_transition((*rcMiddle, liTrans[0]),
+                                    liTrans[1],
+                                    bAddRemove,
+                                    remove_deadends=not bDeadend)
+
+            # Also set the reverse transition
+            # use the reversed outbound transition for inbound
+            # and the reversed inbound transition for outbound
+            grid_map.set_transition((*rcMiddle, mirror(liTrans[1])),
+                                    mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend)
 
 
 def realistic_rail_generator(num_cities=5,
@@ -309,7 +309,7 @@ def realistic_rail_generator(num_cities=5,
                 if len(data) > 2 and len(data1) > 2:
                     for i in np.random.choice(min(len(data1), len(data)) - 2,
                                               intern_nbr_of_switches_per_station_track):
-                        add_rail(width, height, grid_map, data[i + 1], data1[i + 1], data1[i + 2], True)
+                        Grid_Map_Op.add_rail(width, height, grid_map, data[i + 1], data1[i + 1], data1[i + 2], True)
                         nodes_added.append(data[i + 1])
                         nodes_added.append(data1[i + 1])
                         nodes_added.append(data1[i + 2])
@@ -579,15 +579,15 @@ def realistic_rail_generator(num_cities=5,
 for itrials in range(100):
     print(itrials, "generate new city")
     np.random.seed(int(time.time()))
-    env = RailEnv(width=40+np.random.choice(100),
-                  height=40+np.random.choice(100),
+    env = RailEnv(width=40 + np.random.choice(100),
+                  height=40 + np.random.choice(100),
                   rail_generator=realistic_rail_generator(num_cities=2 + np.random.choice(10),
                                                           city_size=10 + np.random.choice(10),
-                                                          allowed_rotation_angles=[-90,-45,0,45,90],
-                                                          max_number_of_station_tracks=np.random.choice(4)+4,
-                                                          nbr_of_switches_per_station_track=np.random.choice(4)+2,
-                                                          connect_max_nbr_of_shortes_city=np.random.choice(4)+2,
-                                                          do_random_connect_stations=np.random.choice(1)==0,
+                                                          allowed_rotation_angles=[-90, -45, 0, 45, 90],
+                                                          max_number_of_station_tracks=np.random.choice(4) + 4,
+                                                          nbr_of_switches_per_station_track=np.random.choice(4) + 2,
+                                                          connect_max_nbr_of_shortes_city=np.random.choice(4) + 2,
+                                                          do_random_connect_stations=np.random.choice(1) == 0,
                                                           # Number of cities in map
                                                           seed=int(time.time()),  # Random seed
                                                           print_out_info=False
-- 
GitLab