diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index f2c8e7b5f6d17a879f776457a8768fec22af87a4..a5c86aa1dc8bcdf425391c4f049e691a422c196d 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -13,14 +13,25 @@ from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.utils.rendertools import RenderTool -def realistic_rail_generator(num_cities=5, city_size=10, allowed_rotation_angles=[0, 90], - max_number_of_station_tracks=4, max_number_of_connecting_tracks=4, - seed=0, print_out_info=True) -> RailGenerator: +def realistic_rail_generator(num_cities=5, + city_size=10, + allowed_rotation_angles=[0, 90], + max_number_of_station_tracks=4, + max_number_of_connecting_tracks=4, + do_random_connect_stations=False, + seed=0, + print_out_info=True) -> RailGenerator: """ This is a level generator which generates a realistic rail configurations - :param num_cities: Number of city node (can hold trainstations) + :param num_cities: Number of city node + :param city_size: Length of city measure in cells + :param allowed_rotation_angles: Rotate the city (around center) + :param max_number_of_station_tracks: max number of tracks per station + :param max_number_of_connecting_tracks: max number of connecting track between stations + :param do_random_connect_stations : if false connect the stations along the grid (top,left -> down,right), else rand :param seed: Random Seed + :print_out_info : print debug info :return: ------- numpy.ndarray of type numpy.uint16 @@ -149,7 +160,7 @@ def realistic_rail_generator(num_cities=5, city_size=10, allowed_rotation_angles return nodes_added, station_slots, start_nodes_added, end_nodes_added def connect_stations(rail_trans, rail_array, start_nodes_added, end_nodes_added, nodes_added, - inter_max_number_of_connecting_tracks): + inter_max_number_of_connecting_tracks, do_random_connect_stations): x = np.arange(len(start_nodes_added)) random_city_idx = np.random.choice(x, len(x), False) for city_loop in range(len(random_city_idx) - 1): @@ -159,10 +170,15 @@ def realistic_rail_generator(num_cities=5, city_size=10, allowed_rotation_angles e_nodes = end_nodes_added[idx_b] max_input_output = max(len(s_nodes), len(e_nodes)) - max_input_output = min(inter_max_number_of_connecting_tracks,max_input_output) + max_input_output = min(inter_max_number_of_connecting_tracks, max_input_output) + + if do_random_connect_stations: + idx_s_nodes = np.random.choice(np.arange(len(s_nodes)), len(s_nodes), False) + idx_e_nodes = np.random.choice(np.arange(len(e_nodes)), len(e_nodes), False) + else: + idx_s_nodes = np.arange(len(s_nodes)) + idx_e_nodes = np.arange(len(e_nodes)) - idx_s_nodes = np.random.choice(np.arange(len(s_nodes)), len(s_nodes), False) - idx_e_nodes = np.random.choice(np.arange(len(e_nodes)), len(e_nodes), False) if len(idx_s_nodes) < max_input_output: idx_s_nodes = np.append(idx_s_nodes, np.random.choice(np.arange(len(s_nodes)), max_input_output - len( idx_s_nodes))) @@ -172,9 +188,9 @@ def realistic_rail_generator(num_cities=5, city_size=10, allowed_rotation_angles idx_e_nodes))) if len(idx_s_nodes) > inter_max_number_of_connecting_tracks: - idx_s_nodes = np.random.choice(idx_s_nodes,inter_max_number_of_connecting_tracks,False) + idx_s_nodes = np.random.choice(idx_s_nodes, inter_max_number_of_connecting_tracks, False) if len(idx_e_nodes) > inter_max_number_of_connecting_tracks: - idx_e_nodes = np.random.choice(idx_e_nodes,inter_max_number_of_connecting_tracks,False) + idx_e_nodes = np.random.choice(idx_e_nodes, inter_max_number_of_connecting_tracks, False) for i in range(max_input_output): start_node = s_nodes[idx_s_nodes[i]] @@ -226,7 +242,8 @@ def realistic_rail_generator(num_cities=5, city_size=10, allowed_rotation_angles generate_city_locations, intern_max_number_of_station_tracks) # connect stations - connect_stations(rail_trans, rail_array, s_nodes, e_nodes, nodes_added, inter_max_number_of_connecting_tracks) + connect_stations(rail_trans, rail_array, s_nodes, e_nodes, nodes_added, inter_max_number_of_connecting_tracks, + do_random_connect_stations) # ---------------------------------------------------------------------------------- # fix all transition at starting / ending points (mostly add a dead end, if missing) @@ -285,7 +302,7 @@ for itrials in range(100): height=70, rail_generator=realistic_rail_generator(num_cities=np.random.choice(40) + 2, city_size=np.random.choice(10) + 10, - allowed_rotation_angles=[0], + allowed_rotation_angles=[0, 90], max_number_of_station_tracks=np.random.choice(6) + 4, max_number_of_connecting_tracks=4, # Number of cities in map