diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index ecc4d7ec40fb1ad5b568c6e34dcf040d182e3f0d..51703606dc1d5144672e1036b707437a0c414a88 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -13,7 +13,10 @@ from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct from flatland.envs.schedule_generators import sparse_schedule_generator -from flatland.utils.rendertools import RenderTool, AgentRenderVariant +from flatland.utils.rendertools import AgentRenderVariant, RenderTool + +IntVector2DArrayType = [] +FloatArrayType = [] def realistic_rail_generator(num_cities=5, @@ -44,7 +47,10 @@ def realistic_rail_generator(num_cities=5, The matrix with the correct 16-bit bitmaps for each cell. """ - def do_generate_city_locations(width, height, intern_city_size, intern_max_number_of_station_tracks): + def do_generate_city_locations(width: int, + height: int, + intern_city_size: int, + intern_max_number_of_station_tracks: int) -> (IntVector2DArrayType, int): X = int(np.floor(max(1, height - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size)) Y = int(np.floor(max(1, width - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size)) @@ -64,20 +70,28 @@ def realistic_rail_generator(num_cities=5, generate_city_locations = [[(int(xs[i]), int(ys[i])), (int(xs[i]), int(ys[i]))] for i in range(len(xs))] return generate_city_locations, max_num_cities - def do_orient_cities(generate_city_locations, intern_city_size, rotation_angles_set): + def do_orient_cities(generate_city_locations: IntVector2DArrayType, intern_city_size: int, + rotation_angles_set: FloatArrayType): for i in range(len(generate_city_locations)): # station main orientation (horizontal or vertical rot_angle = np.random.choice(rotation_angles_set) add_pos_val = Vec2d.scale_pos(Vec2d.rotate_pos((1, 0), rot_angle), - (max(1, (intern_city_size - 3) / 2))) + int(max(1.0, (intern_city_size - 3) / 2))) generate_city_locations[i][0] = Vec2d.add_pos(generate_city_locations[i][1], add_pos_val) add_pos_val = Vec2d.scale_pos(Vec2d.rotate_pos((1, 0), 180 + rot_angle), - (max(1, (intern_city_size - 3) / 2))) + int(max(1.0, (intern_city_size - 3) / 2))) generate_city_locations[i][1] = Vec2d.add_pos(generate_city_locations[i][1], add_pos_val) return generate_city_locations - def create_stations_from_city_locations(rail_trans, rail_array, generate_city_locations, - intern_max_number_of_station_tracks): + def create_stations_from_city_locations(rail_trans: RailEnvTransitions, + grid_map: GridTransitionMap, + generate_city_locations: IntVector2DArrayType, + intern_max_number_of_station_tracks: int) -> (IntVector2DArrayType, + IntVector2DArrayType, + IntVector2DArrayType, + IntVector2DArrayType, + IntVector2DArrayType): + nodes_added = [] start_nodes_added = [[] for _ in range(len(generate_city_locations))] end_nodes_added = [[] for _ in range(len(generate_city_locations))] @@ -102,7 +116,7 @@ def realistic_rail_generator(num_cities=5, end_node = Vec2d.ceil_pos( Vec2d.add_pos(org_end_node, Vec2d.scale_pos(ortho_trans, s))) - connection = connect_from_nodes(rail_trans, rail_array, start_node, end_node) + connection = connect_from_nodes(rail_trans, grid_map, start_node, end_node) if len(connection) > 0: nodes_added.append(start_node) nodes_added.append(end_node) @@ -124,8 +138,9 @@ def realistic_rail_generator(num_cities=5, return nodes_added, station_slots, start_nodes_added, end_nodes_added, station_tracks - def create_switches_at_stations(rail_trans, rail_array, station_tracks, nodes_added): - + def create_switches_at_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, + station_tracks: IntVector2DArrayType, + nodes_added: IntVector2DArrayType) -> IntVector2DArrayType: for city_loop in range(len(station_tracks)): datas = station_tracks[city_loop] if len(datas) > 1: @@ -137,7 +152,7 @@ def realistic_rail_generator(num_cities=5, if len(b) > 2: x = np.random.choice(len(b) - 2) + 1 end_node = b[x] - connection = connect_rail(rail_trans, rail_array, start_node, end_node) + connection = connect_rail(rail_trans, grid_map, start_node, end_node) if len(connection) == 0: if print_out_info: print("create_switches_at_stations : connect_rail -> no path found") @@ -176,7 +191,11 @@ def realistic_rail_generator(num_cities=5, print("************* NBR of graphs:", len(np.unique(graph_ids))) return graph, np.unique(graph_ids).astype(int) - def connect_sub_graphs(rail_trans, rail_array, org_s_nodes, org_e_nodes, city_edges, nodes_added): + def connect_sub_graphs(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, + org_s_nodes: IntVector2DArrayType, + org_e_nodes: IntVector2DArrayType, + city_edges: IntVector2DArrayType, + nodes_added: IntVector2DArrayType): _, graphids = calc_nbr_of_graphs(city_edges) if len(graphids) > 0: for i in range(len(graphids) - 1): @@ -192,9 +211,9 @@ def realistic_rail_generator(num_cities=5, # TODO : removing, what the hell is going on, why we have to set rail_array -> transition to zero # TODO : before we can call connect_rail. If we don't reset the transistion to zero -> no rail # TODO : will be generated. - rail_array[start_node] = 0 - rail_array[end_node] = 0 - connection = connect_rail(rail_trans, rail_array, start_node, end_node) + grid_map.grid[start_node] = 0 + grid_map.grid[end_node] = 0 + connection = connect_rail(rail_trans, grid_map, start_node, end_node) if len(connection) > 0: nodes_added.append(start_node) nodes_added.append(end_node) @@ -204,9 +223,12 @@ def realistic_rail_generator(num_cities=5, iteration_counter += 1 - def connect_stations(rail_trans, rail_array, org_s_nodes, org_e_nodes, nodes_added, - inter_connect_max_nbr_of_shortes_city): - + def connect_stations(rail_trans: RailEnvTransitions, + grid_map: GridTransitionMap, + org_s_nodes: IntVector2DArrayType, + org_e_nodes: IntVector2DArrayType, + nodes_added: IntVector2DArrayType, + inter_connect_max_nbr_of_shortes_city: int): city_edges = [] s_nodes = copy.deepcopy(org_s_nodes) @@ -231,11 +253,11 @@ def realistic_rail_generator(num_cities=5, cl = city_loop_find_shortest if end_node is not None: - tmp_trans_sn = rail_array[start_node] - tmp_trans_en = rail_array[end_node] - rail_array[start_node] = 0 - rail_array[end_node] = 0 - connection = connect_rail(rail_trans, rail_array, start_node, end_node) + tmp_trans_sn = grid_map.grid[start_node] + tmp_trans_en = grid_map.grid[end_node] + grid_map.grid[start_node] = 0 + grid_map.grid[end_node] = 0 + connection = connect_rail(rail_trans, grid_map, start_node, end_node) if len(connection) > 0: s_nodes[city_loop].remove(start_node) e_nodes[cl].remove(end_node) @@ -250,13 +272,16 @@ def realistic_rail_generator(num_cities=5, if print_out_info: print("connect_stations : connect_rail -> no path found") - rail_array[start_node] = tmp_trans_sn - rail_array[end_node] = tmp_trans_en + grid_map.grid[start_node] = tmp_trans_sn + grid_map.grid[end_node] = tmp_trans_en - connect_sub_graphs(rail_trans, rail_array, org_s_nodes, org_e_nodes, city_edges, nodes_added) + connect_sub_graphs(rail_trans, grid_map, org_s_nodes, org_e_nodes, city_edges, nodes_added) - def connect_random_stations(rail_trans, rail_array, start_nodes_added, end_nodes_added, nodes_added, - inter_connect_max_nbr_of_shortes_city): + def connect_random_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, + start_nodes_added: IntVector2DArrayType, + end_nodes_added: IntVector2DArrayType, + nodes_added: IntVector2DArrayType, + inter_connect_max_nbr_of_shortes_city: int): x = np.arange(len(start_nodes_added)) random_city_idx = np.random.choice(x, len(x), False) @@ -295,9 +320,9 @@ def realistic_rail_generator(num_cities=5, for i in range(max_input_output): start_node = s_nodes[idx_s_nodes[i]] end_node = e_nodes[idx_e_nodes[i]] - rail_array[start_node] = 0 - rail_array[end_node] = 0 - connection = connect_nodes(rail_trans, rail_array, start_node, end_node) + grid_map.grid[start_node] = 0 + grid_map.grid[end_node] = 0 + connection = connect_nodes(rail_trans, grid_map, start_node, end_node) if len(connection) > 0: nodes_added.append(start_node) nodes_added.append(end_node) @@ -305,11 +330,10 @@ def realistic_rail_generator(num_cities=5, if print_out_info: print("connect_random_stations : connect_nodes -> no path found") - def generator(width, height, num_agents, num_resets=0) -> RailGeneratorProduct: + def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct: rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) - rail_array = grid_map.grid - rail_array.fill(0) + grid_map.grid.fill(0) np.random.seed(seed + num_resets) intern_city_size = city_size @@ -354,23 +378,23 @@ def realistic_rail_generator(num_cities=5, # ---------------------------------------------------------------------------------- # generate city topology nodes_added, train_stations, s_nodes, e_nodes, station_tracks = \ - create_stations_from_city_locations(rail_trans, rail_array, + create_stations_from_city_locations(rail_trans, grid_map, generate_city_locations, intern_max_number_of_station_tracks) # build switches # TODO remove true/false block if True: - create_switches_at_stations(rail_trans, rail_array, station_tracks, nodes_added) + create_switches_at_stations(rail_trans, grid_map, station_tracks, nodes_added) # ---------------------------------------------------------------------------------- # connect stations # TODO remove true/false block if True: if do_random_connect_stations: - connect_random_stations(rail_trans, rail_array, s_nodes, e_nodes, nodes_added, + connect_random_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added, inter_connect_max_nbr_of_shortes_city) else: - connect_stations(rail_trans, rail_array, s_nodes, e_nodes, nodes_added, + connect_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added, inter_connect_max_nbr_of_shortes_city) # ---------------------------------------------------------------------------------- @@ -435,7 +459,7 @@ for itrials in range(1000): max_number_of_station_tracks=np.random.choice(4) + 4, nbr_of_switches_per_station_track=np.random.choice(4) + 2, connect_max_nbr_of_shortes_city=2, - do_random_connect_stations=False and np.random.choice(1) == 0, + do_random_connect_stations=np.random.choice(1) == 0, # Number of cities in map seed=int(time.time()), # Random seed print_out_info=True diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index ab0d1a94cdd7cfbdc2b29afcae52461b30ef0e44..c1bdedebd3edef2c22ef971a538fa55c5c1482a6 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -7,15 +7,20 @@ a GridTransitionMap object. from flatland.core.grid.grid4_astar import a_star from flatland.core.grid.grid4_utils import get_direction, mirror +from flatland.core.grid.grid_utils import IntVector2D +from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions -def connect_basic_operation(rail_trans, rail_array, start, end, - flip_start_node_trans=False, flip_end_node_trans=False): +def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, + start: IntVector2D, + end: IntVector2D, + flip_start_node_trans=False, + flip_end_node_trans=False): """ - Creates a new path [start,end] in rail_array, based on rail_trans. + Creates a new path [start,end] in grid_map, based on rail_trans. """ # in the worst case we will need to do a A* search, so we might as well set that up - path = a_star(rail_trans, rail_array, start, end) + path = a_star(rail_trans, grid_map.grid, start, end) if len(path) < 2: return [] current_dir = get_direction(path[0], path[1]) @@ -25,7 +30,7 @@ def connect_basic_operation(rail_trans, rail_array, start, end, new_pos = path[index + 1] new_dir = get_direction(current_pos, new_pos) - new_trans = rail_array[current_pos] + new_trans = grid_map.grid[current_pos] if index == 0: if new_trans == 0: # end-point @@ -42,11 +47,11 @@ def connect_basic_operation(rail_trans, rail_array, start, end, new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) # set the backwards path new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - rail_array[current_pos] = new_trans + grid_map.grid[current_pos] = new_trans if new_pos == end_pos: # setup end pos setup - new_trans_e = rail_array[end_pos] + new_trans_e = grid_map.grid[end_pos] if new_trans_e == 0: # end-point if flip_end_node_trans: @@ -56,23 +61,24 @@ def connect_basic_operation(rail_trans, rail_array, start, end, else: # into existing rail new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) - rail_array[end_pos] = new_trans_e + grid_map.grid[end_pos] = new_trans_e current_dir = new_dir return path -def connect_rail(rail_trans, rail_array, start, end): - return connect_basic_operation(rail_trans, rail_array, start, end, True, True) +def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D): + return connect_basic_operation(rail_trans, grid_map, start, end, True, True) -def connect_nodes(rail_trans, rail_array, start, end): - return connect_basic_operation(rail_trans, rail_array, start, end, False, False) +def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D): + return connect_basic_operation(rail_trans, grid_map, start, end, False, False) -def connect_from_nodes(rail_trans, rail_array, start, end): - return connect_basic_operation(rail_trans, rail_array, start, end, False, True) +def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, + end: IntVector2D): + return connect_basic_operation(rail_trans, grid_map, start, end, False, True) -def connect_to_nodes(rail_trans, rail_array, start, end): - return connect_basic_operation(rail_trans, rail_array, start, end, True, False) +def connect_to_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D): + return connect_basic_operation(rail_trans, grid_map, start, end, True, False) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index a16fb6018a6354665a44c1b44cafd6975bb4e680..d9ed1876e444258924c87874f1fda5ab8616be22 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -121,7 +121,7 @@ def complex_rail_generator(nr_start_goal=1, # we might as well give up at this point break - new_path = connect_rail(rail_trans, rail_array, start, goal) + new_path = connect_rail(rail_trans, grid_map, start, goal) if len(new_path) >= 2: nr_created += 1 start_goal.append([start, goal]) @@ -146,7 +146,7 @@ def complex_rail_generator(nr_start_goal=1, break if not all_ok: break - new_path = connect_rail(rail_trans, rail_array, start, goal) + new_path = connect_rail(rail_trans, grid_map, start, goal) if len(new_path) >= 2: nr_created += 1 @@ -645,7 +645,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 for neighb in connected_neighb_idx: if neighb not in node_stack: node_stack.append(neighb) - connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb]) + connect_nodes(rail_trans, grid_map, node_positions[current_node], node_positions[neighb]) node_stack.pop(0) # Place train stations close to the node @@ -688,7 +688,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 train_stations[trainstation_node].append((station_x, station_y)) # Connect train station to the correct node - connection = connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], + connection = connect_from_nodes(rail_trans, grid_map, node_positions[trainstation_node], (station_x, station_y)) # Check if connection was made if len(connection) == 0: @@ -723,11 +723,11 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 width - 2) # Connect train station to the correct node - connect_nodes(rail_trans, rail_array, (intersect_x_1, intersect_y_1), + connect_nodes(rail_trans, grid_map, (intersect_x_1, intersect_y_1), (intersect_x_2, intersect_y_2)) - connect_nodes(rail_trans, rail_array, intersection_positions[intersection], + connect_nodes(rail_trans, grid_map, intersection_positions[intersection], (intersect_x_1, intersect_y_1)) - connect_nodes(rail_trans, rail_array, intersection_positions[intersection], + connect_nodes(rail_trans, grid_map, intersection_positions[intersection], (intersect_x_2, intersect_y_2)) grid_map.fix_transitions((intersect_x_1, intersect_y_1)) grid_map.fix_transitions((intersect_x_2, intersect_y_2))