From 2b748303805df307a9a8f50458df0eaa506e6bd2 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Tue, 17 Sep 2019 07:54:51 +0200
Subject: [PATCH] refactoring and clean up

---
 .../Simple_Realistic_Railway_Generator.py     | 113 +++++-------------
 1 file changed, 30 insertions(+), 83 deletions(-)

diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py
index f1649e78..0ea9202a 100644
--- a/examples/Simple_Realistic_Railway_Generator.py
+++ b/examples/Simple_Realistic_Railway_Generator.py
@@ -5,8 +5,7 @@ import warnings
 
 import numpy as np
 
-from flatland.core.grid.grid4_utils import mirror
-from flatland.core.grid.grid_utils import Vec2dOperations as vec2d
+from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.grid4_generators_utils import connect_from_nodes, connect_nodes, connect_rail
@@ -17,62 +16,9 @@ from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 
 
-# TODO : remove (reuse existing code!!)
-class GripMapOp:
-    def min_max_cut(min_v, max_v, v):
-        return max(min_v, min(max_v, v))
-
-    def add_rail(width, height, grid_map, pt_from, pt_via, pt_to, bAddRemove=True):
-        gRCTrans = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])  # NESW in RC
-
-        lrcStroke = [[GripMapOp.min_max_cut(0, height - 1, pt_from[0]),
-                      GripMapOp.min_max_cut(0, width - 1, pt_from[1])],
-                     [GripMapOp.min_max_cut(0, height - 1, pt_via[0]),
-                      GripMapOp.min_max_cut(0, width - 1, pt_via[1])],
-                     [GripMapOp.min_max_cut(0, height - 1, pt_to[0]),
-                      GripMapOp.min_max_cut(0, width - 1, pt_to[1])]]
-
-        rc3Cells = np.array(lrcStroke[:3])  # the 3 cells
-        rcMiddle = rc3Cells[1]  # the middle cell which we will update
-        bDeadend = np.all(lrcStroke[0] == lrcStroke[2])  # deadend means cell 0 == cell 2
-
-        # get the 2 row, col deltas between the 3 cells, eg [[-1,0],[0,1]] = North, East
-        rc2Trans = np.diff(rc3Cells, axis=0)
-
-        # get the direction index for the 2 transitions
-        liTrans = []
-        for rcTrans in rc2Trans:
-            # gRCTrans - rcTrans gives an array of vector differences between our rcTrans
-            # and the 4 directions stored in gRCTrans.
-            # Where the vector difference is zero, we have a match...
-            # np.all detects where the whole row,col vector is zero.
-            # argwhere gives the index of the zero vector, ie the direction index
-            iTrans = np.argwhere(np.all(gRCTrans - rcTrans == 0, axis=1))
-            if len(iTrans) > 0:
-                iTrans = iTrans[0][0]
-                liTrans.append(iTrans)
-
-        # check that we have two transitions
-        if len(liTrans) == 2:
-            # Set the transition
-            # Set the transition
-            # If this transition spans 3 cells, it is not a deadend, so remove any deadends.
-            # The user will need to resolve any conflicts.
-            grid_map.set_transition((*rcMiddle, liTrans[0]),
-                                    liTrans[1],
-                                    bAddRemove,
-                                    remove_deadends=not bDeadend)
-
-            # Also set the reverse transition
-            # use the reversed outbound transition for inbound
-            # and the reversed inbound transition for outbound
-            grid_map.set_transition((*rcMiddle, mirror(liTrans[1])),
-                                    mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend)
-
-
 def realistic_rail_generator(num_cities=5,
                              city_size=10,
-                             allowed_rotation_angles=[0, 90],
+                             allowed_rotation_angles=None,
                              max_number_of_station_tracks=4,
                              nbr_of_switches_per_station_track=2,
                              connect_max_nbr_of_shortes_city=4,
@@ -82,6 +28,7 @@ def realistic_rail_generator(num_cities=5,
     """
     This is a level generator which generates a realistic rail configurations
 
+    :param print_out_info:
     :param num_cities: Number of city node
     :param city_size: Length of city measure in cells
     :param allowed_rotation_angles: Rotate the city (around center)
@@ -117,25 +64,25 @@ def realistic_rail_generator(num_cities=5,
         generate_city_locations = [[(int(xs[i]), int(ys[i])), (int(xs[i]), int(ys[i]))] for i in range(len(xs))]
         return generate_city_locations, max_num_cities
 
-    def do_orient_cities(generate_city_locations, intern_city_size, allowed_rotation_angles):
+    def do_orient_cities(generate_city_locations, intern_city_size, rotation_angles_set):
         for i in range(len(generate_city_locations)):
             # station main orientation  (horizontal or vertical
-            rot_angle = np.random.choice(allowed_rotation_angles)
-            add_pos_val = vec2d.scale_pos(vec2d.rotate_pos((1, 0), rot_angle),
+            rot_angle = np.random.choice(rotation_angles_set)
+            add_pos_val = Vec2d.scale_pos(Vec2d.rotate_pos((1, 0), rot_angle),
                                           (max(1, (intern_city_size - 3) / 2)))
-            generate_city_locations[i][0] = vec2d.add_pos(generate_city_locations[i][1], add_pos_val)
-            add_pos_val = vec2d.scale_pos(vec2d.rotate_pos((1, 0), 180 + rot_angle),
+            generate_city_locations[i][0] = Vec2d.add_pos(generate_city_locations[i][1], add_pos_val)
+            add_pos_val = Vec2d.scale_pos(Vec2d.rotate_pos((1, 0), 180 + rot_angle),
                                           (max(1, (intern_city_size - 3) / 2)))
-            generate_city_locations[i][1] = vec2d.add_pos(generate_city_locations[i][1], add_pos_val)
+            generate_city_locations[i][1] = Vec2d.add_pos(generate_city_locations[i][1], add_pos_val)
         return generate_city_locations
 
     def create_stations_from_city_locations(rail_trans, rail_array, generate_city_locations,
                                             intern_max_number_of_station_tracks):
         nodes_added = []
-        start_nodes_added = [[] for i in range(len(generate_city_locations))]
-        end_nodes_added = [[] for i in range(len(generate_city_locations))]
-        station_slots = [[] for i in range(len(generate_city_locations))]
-        station_tracks = [[[] for j in range(intern_max_number_of_station_tracks)] for i in range(len(
+        start_nodes_added = [[] for _ in range(len(generate_city_locations))]
+        end_nodes_added = [[] for _ in range(len(generate_city_locations))]
+        station_slots = [[] for _ in range(len(generate_city_locations))]
+        station_tracks = [[[] for _ in range(intern_max_number_of_station_tracks)] for _ in range(len(
             generate_city_locations))]
 
         station_slots_cnt = 0
@@ -147,13 +94,13 @@ def realistic_rail_generator(num_cities=5,
                 org_start_node = generate_city_locations[city_loop][0]
                 org_end_node = generate_city_locations[city_loop][1]
 
-                ortho_trans = vec2d.make_orthogonal_pos(
-                    vec2d.normalize_pos(vec2d.subtract_pos(org_start_node, org_end_node)))
+                ortho_trans = Vec2d.make_orthogonal_pos(
+                    Vec2d.normalize_pos(Vec2d.subtract_pos(org_start_node, org_end_node)))
                 s = (ct - number_of_connecting_tracks / 2.0)
-                start_node = vec2d.ceil_pos(
-                    vec2d.add_pos(org_start_node, vec2d.scale_pos(ortho_trans, s)))
-                end_node = vec2d.ceil_pos(
-                    vec2d.add_pos(org_end_node, vec2d.scale_pos(ortho_trans, s)))
+                start_node = Vec2d.ceil_pos(
+                    Vec2d.add_pos(org_start_node, Vec2d.scale_pos(ortho_trans, s)))
+                end_node = Vec2d.ceil_pos(
+                    Vec2d.add_pos(org_end_node, Vec2d.scale_pos(ortho_trans, s)))
 
                 connection = connect_from_nodes(rail_trans, rail_array, start_node, end_node)
                 if len(connection) > 0:
@@ -174,8 +121,7 @@ def realistic_rail_generator(num_cities=5,
 
         return nodes_added, station_slots, start_nodes_added, end_nodes_added, station_tracks
 
-    def create_switches_at_stations(rail_trans, rail_array, width, height, grid_map, station_tracks, nodes_added,
-                                    intern_nbr_of_switches_per_station_track):
+    def create_switches_at_stations(rail_trans, rail_array, station_tracks, nodes_added):
 
         for city_loop in range(len(station_tracks)):
             datas = station_tracks[city_loop]
@@ -183,13 +129,14 @@ def realistic_rail_generator(num_cities=5,
                 a = datas[0]
                 if len(a) > 0:
                     start_node = a[np.random.choice(len(a) - 2) + 1]
-                    b = []
                     for i in np.arange(1, len(datas)):
                         b = datas[i]
                         if len(b) > 2:
                             x = np.random.choice(len(b) - 2) + 1
                             end_node = b[x]
                             connection = connect_rail(rail_trans, rail_array, start_node, end_node)
+                            if len(connection) == 0:
+                                print("create_switches_at_stations : connect_rail -> no path found")
                             nodes_added.append(start_node)
                             nodes_added.append(end_node)
                             start_node = b[np.random.choice(len(b) - 2) + 1]
@@ -230,8 +177,8 @@ def realistic_rail_generator(num_cities=5,
         if len(graphids) > 0:
             for i in range(len(graphids) - 1):
                 connection = []
-                cnt = 0
-                while len(connection) == 0 and cnt < 100:
+                iteration_counter = 0
+                while len(connection) == 0 and iteration_counter < 100:
                     s_nodes = copy.deepcopy(org_s_nodes)
                     e_nodes = copy.deepcopy(org_e_nodes)
                     start_nodes = s_nodes[graphids[i]]
@@ -247,7 +194,7 @@ def realistic_rail_generator(num_cities=5,
                     if len(connection) > 0:
                         nodes_added.append(start_node)
                         nodes_added.append(end_node)
-                    cnt += 1
+                    iteration_counter += 1
 
     def connect_stations(rail_trans, rail_array, org_s_nodes, org_e_nodes, nodes_added,
                          inter_connect_max_nbr_of_shortes_city):
@@ -263,12 +210,13 @@ def realistic_rail_generator(num_cities=5,
                 for start_node in sns:
                     min_distance = np.inf
                     end_node = None
+                    cl = 0
                     for city_loop_find_shortest in range(len(e_nodes)):
                         if city_loop_find_shortest == city_loop:
                             continue
                         ens = e_nodes[city_loop_find_shortest]
                         for en in ens:
-                            d = vec2d.get_norm_pos(vec2d.subtract_pos(en, start_node))
+                            d = Vec2d.get_norm_pos(Vec2d.subtract_pos(en, start_node))
                             if d < min_distance:
                                 min_distance = d
                                 end_node = en
@@ -336,8 +284,8 @@ def realistic_rail_generator(num_cities=5,
             for i in range(max_input_output):
                 start_node = s_nodes[idx_s_nodes[i]]
                 end_node = e_nodes[idx_e_nodes[i]]
-                new_trans = rail_array[start_node] = 0
-                new_trans = rail_array[end_node] = 0
+                rail_array[start_node] = 0
+                rail_array[end_node] = 0
                 connection = connect_nodes(rail_trans, rail_array, start_node, end_node)
                 if len(connection) > 0:
                     nodes_added.append(start_node)
@@ -398,8 +346,7 @@ def realistic_rail_generator(num_cities=5,
         # build switches
         # TODO remove true/false block
         if True:
-            create_switches_at_stations(rail_trans, rail_array, width, height, grid_map, station_tracks, nodes_added,
-                                        intern_nbr_of_switches_per_station_track)
+            create_switches_at_stations(rail_trans, rail_array, station_tracks, nodes_added)
 
         # ----------------------------------------------------------------------------------
         # connect stations
-- 
GitLab