From 4418d3a00a2983f742bbd6ed37fb093d3b9d2214 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Fri, 13 Sep 2019 08:28:50 +0200 Subject: [PATCH] refactoring --- .../Simple_Realistic_Railway_Generator.py | 116 +++++++++--------- 1 file changed, 58 insertions(+), 58 deletions(-) diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index 6f9061b0..41506ba9 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -142,56 +142,56 @@ class Vec2dOperations: return (x1, y1) -def min_max_cut(min_v, max_v, v): - return max(min_v, min(max_v, v)) - - -def add_rail(width, height, grid_map, pt_from, pt_via, pt_to, bAddRemove=True): - gRCTrans = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) # NESW in RC - - lrcStroke = [[min_max_cut(0, height - 1, pt_from[0]), - min_max_cut(0, width - 1, pt_from[1])], - [min_max_cut(0, height - 1, pt_via[0]), - min_max_cut(0, width - 1, pt_via[1])], - [min_max_cut(0, height - 1, pt_to[0]), - min_max_cut(0, width - 1, pt_to[1])]] - - rc3Cells = np.array(lrcStroke[:3]) # the 3 cells - rcMiddle = rc3Cells[1] # the middle cell which we will update - bDeadend = np.all(lrcStroke[0] == lrcStroke[2]) # deadend means cell 0 == cell 2 - - # get the 2 row, col deltas between the 3 cells, eg [[-1,0],[0,1]] = North, East - rc2Trans = np.diff(rc3Cells, axis=0) - - # get the direction index for the 2 transitions - liTrans = [] - for rcTrans in rc2Trans: - # gRCTrans - rcTrans gives an array of vector differences between our rcTrans - # and the 4 directions stored in gRCTrans. - # Where the vector difference is zero, we have a match... - # np.all detects where the whole row,col vector is zero. - # argwhere gives the index of the zero vector, ie the direction index - iTrans = np.argwhere(np.all(gRCTrans - rcTrans == 0, axis=1)) - if len(iTrans) > 0: - iTrans = iTrans[0][0] - liTrans.append(iTrans) - - # check that we have two transitions - if len(liTrans) == 2: - # Set the transition - # Set the transition - # If this transition spans 3 cells, it is not a deadend, so remove any deadends. - # The user will need to resolve any conflicts. - grid_map.set_transition((*rcMiddle, liTrans[0]), - liTrans[1], - bAddRemove, - remove_deadends=not bDeadend) - - # Also set the reverse transition - # use the reversed outbound transition for inbound - # and the reversed inbound transition for outbound - grid_map.set_transition((*rcMiddle, mirror(liTrans[1])), - mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend) +class Grid_Map_Op: + def min_max_cut(min_v, max_v, v): + return max(min_v, min(max_v, v)) + + def add_rail(width, height, grid_map, pt_from, pt_via, pt_to, bAddRemove=True): + gRCTrans = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) # NESW in RC + + lrcStroke = [[Grid_Map_Op.min_max_cut(0, height - 1, pt_from[0]), + Grid_Map_Op.min_max_cut(0, width - 1, pt_from[1])], + [Grid_Map_Op.min_max_cut(0, height - 1, pt_via[0]), + Grid_Map_Op.min_max_cut(0, width - 1, pt_via[1])], + [Grid_Map_Op.min_max_cut(0, height - 1, pt_to[0]), + Grid_Map_Op.min_max_cut(0, width - 1, pt_to[1])]] + + rc3Cells = np.array(lrcStroke[:3]) # the 3 cells + rcMiddle = rc3Cells[1] # the middle cell which we will update + bDeadend = np.all(lrcStroke[0] == lrcStroke[2]) # deadend means cell 0 == cell 2 + + # get the 2 row, col deltas between the 3 cells, eg [[-1,0],[0,1]] = North, East + rc2Trans = np.diff(rc3Cells, axis=0) + + # get the direction index for the 2 transitions + liTrans = [] + for rcTrans in rc2Trans: + # gRCTrans - rcTrans gives an array of vector differences between our rcTrans + # and the 4 directions stored in gRCTrans. + # Where the vector difference is zero, we have a match... + # np.all detects where the whole row,col vector is zero. + # argwhere gives the index of the zero vector, ie the direction index + iTrans = np.argwhere(np.all(gRCTrans - rcTrans == 0, axis=1)) + if len(iTrans) > 0: + iTrans = iTrans[0][0] + liTrans.append(iTrans) + + # check that we have two transitions + if len(liTrans) == 2: + # Set the transition + # Set the transition + # If this transition spans 3 cells, it is not a deadend, so remove any deadends. + # The user will need to resolve any conflicts. + grid_map.set_transition((*rcMiddle, liTrans[0]), + liTrans[1], + bAddRemove, + remove_deadends=not bDeadend) + + # Also set the reverse transition + # use the reversed outbound transition for inbound + # and the reversed inbound transition for outbound + grid_map.set_transition((*rcMiddle, mirror(liTrans[1])), + mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend) def realistic_rail_generator(num_cities=5, @@ -309,7 +309,7 @@ def realistic_rail_generator(num_cities=5, if len(data) > 2 and len(data1) > 2: for i in np.random.choice(min(len(data1), len(data)) - 2, intern_nbr_of_switches_per_station_track): - add_rail(width, height, grid_map, data[i + 1], data1[i + 1], data1[i + 2], True) + Grid_Map_Op.add_rail(width, height, grid_map, data[i + 1], data1[i + 1], data1[i + 2], True) nodes_added.append(data[i + 1]) nodes_added.append(data1[i + 1]) nodes_added.append(data1[i + 2]) @@ -579,15 +579,15 @@ def realistic_rail_generator(num_cities=5, for itrials in range(100): print(itrials, "generate new city") np.random.seed(int(time.time())) - env = RailEnv(width=40+np.random.choice(100), - height=40+np.random.choice(100), + env = RailEnv(width=40 + np.random.choice(100), + height=40 + np.random.choice(100), rail_generator=realistic_rail_generator(num_cities=2 + np.random.choice(10), city_size=10 + np.random.choice(10), - allowed_rotation_angles=[-90,-45,0,45,90], - max_number_of_station_tracks=np.random.choice(4)+4, - nbr_of_switches_per_station_track=np.random.choice(4)+2, - connect_max_nbr_of_shortes_city=np.random.choice(4)+2, - do_random_connect_stations=np.random.choice(1)==0, + allowed_rotation_angles=[-90, -45, 0, 45, 90], + max_number_of_station_tracks=np.random.choice(4) + 4, + nbr_of_switches_per_station_track=np.random.choice(4) + 2, + connect_max_nbr_of_shortes_city=np.random.choice(4) + 2, + do_random_connect_stations=np.random.choice(1) == 0, # Number of cities in map seed=int(time.time()), # Random seed print_out_info=False -- GitLab