Skip to content
Snippets Groups Projects
Commit 4418d3a0 authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files

refactoring

parent 7896afd0
No related branches found
No related tags found
No related merge requests found
...@@ -142,56 +142,56 @@ class Vec2dOperations: ...@@ -142,56 +142,56 @@ class Vec2dOperations:
return (x1, y1) return (x1, y1)
def min_max_cut(min_v, max_v, v): class Grid_Map_Op:
return max(min_v, min(max_v, v)) def min_max_cut(min_v, max_v, v):
return max(min_v, min(max_v, v))
def add_rail(width, height, grid_map, pt_from, pt_via, pt_to, bAddRemove=True): def add_rail(width, height, grid_map, pt_from, pt_via, pt_to, bAddRemove=True):
gRCTrans = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) # NESW in RC gRCTrans = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) # NESW in RC
lrcStroke = [[min_max_cut(0, height - 1, pt_from[0]), lrcStroke = [[Grid_Map_Op.min_max_cut(0, height - 1, pt_from[0]),
min_max_cut(0, width - 1, pt_from[1])], Grid_Map_Op.min_max_cut(0, width - 1, pt_from[1])],
[min_max_cut(0, height - 1, pt_via[0]), [Grid_Map_Op.min_max_cut(0, height - 1, pt_via[0]),
min_max_cut(0, width - 1, pt_via[1])], Grid_Map_Op.min_max_cut(0, width - 1, pt_via[1])],
[min_max_cut(0, height - 1, pt_to[0]), [Grid_Map_Op.min_max_cut(0, height - 1, pt_to[0]),
min_max_cut(0, width - 1, pt_to[1])]] Grid_Map_Op.min_max_cut(0, width - 1, pt_to[1])]]
rc3Cells = np.array(lrcStroke[:3]) # the 3 cells rc3Cells = np.array(lrcStroke[:3]) # the 3 cells
rcMiddle = rc3Cells[1] # the middle cell which we will update rcMiddle = rc3Cells[1] # the middle cell which we will update
bDeadend = np.all(lrcStroke[0] == lrcStroke[2]) # deadend means cell 0 == cell 2 bDeadend = np.all(lrcStroke[0] == lrcStroke[2]) # deadend means cell 0 == cell 2
# get the 2 row, col deltas between the 3 cells, eg [[-1,0],[0,1]] = North, East # get the 2 row, col deltas between the 3 cells, eg [[-1,0],[0,1]] = North, East
rc2Trans = np.diff(rc3Cells, axis=0) rc2Trans = np.diff(rc3Cells, axis=0)
# get the direction index for the 2 transitions # get the direction index for the 2 transitions
liTrans = [] liTrans = []
for rcTrans in rc2Trans: for rcTrans in rc2Trans:
# gRCTrans - rcTrans gives an array of vector differences between our rcTrans # gRCTrans - rcTrans gives an array of vector differences between our rcTrans
# and the 4 directions stored in gRCTrans. # and the 4 directions stored in gRCTrans.
# Where the vector difference is zero, we have a match... # Where the vector difference is zero, we have a match...
# np.all detects where the whole row,col vector is zero. # np.all detects where the whole row,col vector is zero.
# argwhere gives the index of the zero vector, ie the direction index # argwhere gives the index of the zero vector, ie the direction index
iTrans = np.argwhere(np.all(gRCTrans - rcTrans == 0, axis=1)) iTrans = np.argwhere(np.all(gRCTrans - rcTrans == 0, axis=1))
if len(iTrans) > 0: if len(iTrans) > 0:
iTrans = iTrans[0][0] iTrans = iTrans[0][0]
liTrans.append(iTrans) liTrans.append(iTrans)
# check that we have two transitions # check that we have two transitions
if len(liTrans) == 2: if len(liTrans) == 2:
# Set the transition # Set the transition
# Set the transition # Set the transition
# If this transition spans 3 cells, it is not a deadend, so remove any deadends. # If this transition spans 3 cells, it is not a deadend, so remove any deadends.
# The user will need to resolve any conflicts. # The user will need to resolve any conflicts.
grid_map.set_transition((*rcMiddle, liTrans[0]), grid_map.set_transition((*rcMiddle, liTrans[0]),
liTrans[1], liTrans[1],
bAddRemove, bAddRemove,
remove_deadends=not bDeadend) remove_deadends=not bDeadend)
# Also set the reverse transition # Also set the reverse transition
# use the reversed outbound transition for inbound # use the reversed outbound transition for inbound
# and the reversed inbound transition for outbound # and the reversed inbound transition for outbound
grid_map.set_transition((*rcMiddle, mirror(liTrans[1])), grid_map.set_transition((*rcMiddle, mirror(liTrans[1])),
mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend) mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend)
def realistic_rail_generator(num_cities=5, def realistic_rail_generator(num_cities=5,
...@@ -309,7 +309,7 @@ def realistic_rail_generator(num_cities=5, ...@@ -309,7 +309,7 @@ def realistic_rail_generator(num_cities=5,
if len(data) > 2 and len(data1) > 2: if len(data) > 2 and len(data1) > 2:
for i in np.random.choice(min(len(data1), len(data)) - 2, for i in np.random.choice(min(len(data1), len(data)) - 2,
intern_nbr_of_switches_per_station_track): intern_nbr_of_switches_per_station_track):
add_rail(width, height, grid_map, data[i + 1], data1[i + 1], data1[i + 2], True) Grid_Map_Op.add_rail(width, height, grid_map, data[i + 1], data1[i + 1], data1[i + 2], True)
nodes_added.append(data[i + 1]) nodes_added.append(data[i + 1])
nodes_added.append(data1[i + 1]) nodes_added.append(data1[i + 1])
nodes_added.append(data1[i + 2]) nodes_added.append(data1[i + 2])
...@@ -579,15 +579,15 @@ def realistic_rail_generator(num_cities=5, ...@@ -579,15 +579,15 @@ def realistic_rail_generator(num_cities=5,
for itrials in range(100): for itrials in range(100):
print(itrials, "generate new city") print(itrials, "generate new city")
np.random.seed(int(time.time())) np.random.seed(int(time.time()))
env = RailEnv(width=40+np.random.choice(100), env = RailEnv(width=40 + np.random.choice(100),
height=40+np.random.choice(100), height=40 + np.random.choice(100),
rail_generator=realistic_rail_generator(num_cities=2 + np.random.choice(10), rail_generator=realistic_rail_generator(num_cities=2 + np.random.choice(10),
city_size=10 + np.random.choice(10), city_size=10 + np.random.choice(10),
allowed_rotation_angles=[-90,-45,0,45,90], allowed_rotation_angles=[-90, -45, 0, 45, 90],
max_number_of_station_tracks=np.random.choice(4)+4, max_number_of_station_tracks=np.random.choice(4) + 4,
nbr_of_switches_per_station_track=np.random.choice(4)+2, nbr_of_switches_per_station_track=np.random.choice(4) + 2,
connect_max_nbr_of_shortes_city=np.random.choice(4)+2, connect_max_nbr_of_shortes_city=np.random.choice(4) + 2,
do_random_connect_stations=np.random.choice(1)==0, do_random_connect_stations=np.random.choice(1) == 0,
# Number of cities in map # Number of cities in map
seed=int(time.time()), # Random seed seed=int(time.time()), # Random seed
print_out_info=False print_out_info=False
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment