diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index 2eaa412f41a5a6116a7a1c3e995c2cf8db63f18d..bea2bcdc73a28888c795a1559b52358933240e21 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -235,11 +235,22 @@ def realistic_rail_generator(num_cities=5, # place in the center of path a station slot station_slots[city_loop].append(connection[int(np.floor(len(connection) / 2))]) station_slots_cnt += 1 - if len(connection) - 3 > 0: - idxs = np.random.choice(len(connection) - 3, intern_nbr_of_switches_per_station_track, False) - for idx in idxs: - switch_slots[city_loop].append(connection[idx + 1]) + # generate random switch positions (switch slots) + if len(connection) - 3 - nbr_of_switches_per_station_track - 1> 0: + idxs = np.sort(np.random.choice(np.arange(len(connection) - 3), + nbr_of_switches_per_station_track + 1,False)) + idx_loop_cnt = 0 + for idx in idxs: + pt = connection[idx + 1] + if idx_loop_cnt % 2 == 1: + s = (ct - number_of_connecting_tracks / 2.0) + pt = PositionOps.ceil_pos( + PositionOps.add_pos(pt, PositionOps.scale_pos(ortho_trans, s))) + switch_slots[city_loop].append(pt) + idx_loop_cnt += 1 + + # generate switch based on switch slot list and connect them for city_loop in range(len(switch_slots)): data = switch_slots[city_loop] data_idx = np.random.choice(np.arange(len(data)), len(data), False) @@ -248,8 +259,10 @@ def realistic_rail_generator(num_cities=5, end_node = data[data_idx[i + 1]] connection = connect_from_nodes(rail_trans, rail_array, start_node, end_node) if len(connection) > 0: + station_slots[city_loop].append(connection[int(np.floor(len(connection) / 2))]) nodes_added.append(start_node) nodes_added.append(end_node) + new_trans = rail_array[end_node] = 0 if print_out_info: print("max nbr of station slots with given configuration is:", station_slots_cnt) @@ -275,6 +288,9 @@ def realistic_rail_generator(num_cities=5, b.append(np.argmin(a[i])) random_city_idx = np.argsort(b) + # cyclic connection + random_city_idx = np.append(random_city_idx,random_city_idx[0]) + for city_loop in range(len(random_city_idx) - 1): idx_a = random_city_idx[city_loop + 1] idx_b = random_city_idx[city_loop] @@ -350,19 +366,27 @@ def realistic_rail_generator(num_cities=5, print("inter_max_number_of_connecting_tracks:", inter_max_number_of_connecting_tracks) agent_start_targets_nodes = [] + + # ---------------------------------------------------------------------------------- # generate city locations generate_city_locations, max_num_cities = do_generate_city_locations(width, height, intern_city_size, intern_max_number_of_station_tracks) + + # ---------------------------------------------------------------------------------- # apply orientation to cities (horizontal, vertical) generate_city_locations = do_orient_cities(generate_city_locations, intern_city_size, allowed_rotation_angles) + + # ---------------------------------------------------------------------------------- # generate city topology nodes_added, train_stations, s_nodes, e_nodes = \ create_stations_from_city_locations(rail_trans, rail_array, generate_city_locations, intern_max_number_of_station_tracks, intern_nbr_of_switches_per_station_track) + # ---------------------------------------------------------------------------------- # connect stations - connect_stations(rail_trans, rail_array, s_nodes, e_nodes, nodes_added, inter_max_number_of_connecting_tracks, + connect_stations(rail_trans, rail_array, s_nodes, e_nodes, nodes_added, + inter_max_number_of_connecting_tracks, do_random_connect_stations) # ---------------------------------------------------------------------------------- @@ -418,15 +442,15 @@ def realistic_rail_generator(num_cities=5, for itrials in range(100): print(itrials, "generate new city") np.random.seed(int(time.time())) - env = RailEnv(width=80, - height=20, + env = RailEnv(width=120, + height=120, rail_generator=realistic_rail_generator(num_cities=10, - city_size=10, - allowed_rotation_angles=[-90], + city_size=20, + allowed_rotation_angles=[-90,0,90], max_number_of_station_tracks=4, nbr_of_switches_per_station_track=2, - max_number_of_connecting_tracks=10, - do_random_connect_stations=False, + max_number_of_connecting_tracks=1, + do_random_connect_stations=True, # Number of cities in map seed=int(time.time()) # Random seed ),