diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index a5c86aa1dc8bcdf425391c4f049e691a422c196d..8176a3b4392f859116e10b5bd7facea010c5026d 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -140,7 +140,7 @@ def realistic_rail_generator(num_cities=5, station_slots[city_loop].append(connection[int(np.floor(len(connection) / 2))]) station_slots_cnt += 1 if len(connection) - 3 > 0: - idxs = np.random.choice(len(connection) - 2, 1 + np.random.choice(len(connection) - 3), False) + idxs = np.random.choice(len(connection) - 3, 1 + np.random.choice(len(connection) - 3), False) for idx in idxs: switch_slots[city_loop].append(connection[idx + 1]) @@ -162,10 +162,24 @@ def realistic_rail_generator(num_cities=5, def connect_stations(rail_trans, rail_array, start_nodes_added, end_nodes_added, nodes_added, inter_max_number_of_connecting_tracks, do_random_connect_stations): x = np.arange(len(start_nodes_added)) - random_city_idx = np.random.choice(x, len(x), False) + if do_random_connect_stations: + random_city_idx = np.random.choice(x, len(x), False) + else: + a = [[] for i in x] + b = [] + for yLoop in x: + for xLoop in x: + v = get_norm_pos(subtract_pos(start_nodes_added[xLoop][0], end_nodes_added[yLoop][0])) + if v > 0: + v = np.inf + a[yLoop].append(v) + for i in range(len(a)): + b.append(np.argmin(a[i])) + random_city_idx = np.argsort(b) + for city_loop in range(len(random_city_idx) - 1): - idx_a = random_city_idx[city_loop] - idx_b = random_city_idx[city_loop + 1] + idx_a = random_city_idx[city_loop + 1] + idx_b = random_city_idx[city_loop] s_nodes = start_nodes_added[idx_a] e_nodes = end_nodes_added[idx_b] @@ -300,11 +314,12 @@ for itrials in range(100): np.random.seed(int(time.time())) env = RailEnv(width=70, height=70, - rail_generator=realistic_rail_generator(num_cities=np.random.choice(40) + 2, - city_size=np.random.choice(10) + 10, - allowed_rotation_angles=[0, 90], - max_number_of_station_tracks=np.random.choice(6) + 4, - max_number_of_connecting_tracks=4, + rail_generator=realistic_rail_generator(num_cities=20, + city_size=10, + allowed_rotation_angles=[90], + max_number_of_station_tracks=1, + max_number_of_connecting_tracks=1, + do_random_connect_stations=False, # Number of cities in map seed=int(time.time()) # Random seed ),