diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index 7f5446a1be3e9bca64caf34c953b3755dca6d5d7..da007717a45293d4426c83aff55856e172182ed9 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -13,6 +13,46 @@ from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.utils.rendertools import RenderTool +class PositionOps: + def subtract_pos(nodeA, nodeB): + return (nodeA[0] - nodeB[0], nodeA[1] - nodeB[1]) + + def add_pos(nodeA, nodeB): + return (nodeA[0] + nodeB[0], nodeA[1] + nodeB[1]) + + def make_orthogonal_pos(node): + return (node[1], -node[0]) + + def get_norm_pos(node): + return np.sqrt(node[0] * node[0] + node[1] * node[1]) + + def normalize_pos(node): + n = PositionOps.get_norm_pos(node) + if n > 0.0: + n = 1 / n + return PositionOps.scale_pos(node, n) + + def scale_pos(node, scalar): + return (node[0] * scalar, node[1] * scalar) + + def round_pos(node): + return (int(np.round(node[0])), int(np.round(node[1]))) + + def ceil_pos(node): + return (int(np.ceil(node[0])), int(np.ceil(node[1]))) + + def bound_pos(node, min_value, max_value): + return (max(min_value, min(max_value, node[0])), max(min_value, min(max_value, node[1]))) + + def rotate_pos(node, rot_in_degree): + alpha = rot_in_degree / 180.0 * np.pi + x0 = node[0] + y0 = node[1] + x1 = x0 * np.cos(alpha) - y0 * np.sin(alpha) + y1 = x0 * np.sin(alpha) + y0 * np.cos(alpha) + return (x1, y1) + + def realistic_rail_generator(num_cities=5, city_size=10, allowed_rotation_angles=[0, 90], @@ -40,49 +80,6 @@ def realistic_rail_generator(num_cities=5, The matrix with the correct 16-bit bitmaps for each cell. """ - class PositionOps: - def subtract_pos(nodeA, nodeB): - return (nodeA[0] - nodeB[0], nodeA[1] - nodeB[1]) - - def add_pos(nodeA, nodeB): - return (nodeA[0] + nodeB[0], nodeA[1] + nodeB[1]) - - def make_orthogonal_pos(node): - return (node[1], -node[0]) - - def get_norm_pos(node): - return np.sqrt(node[0] * node[0] + node[1] * node[1]) - - def normalize_pos(node): - n = PositionOps.get_norm_pos(node) - if n > 0.0: - n = 1 / n - return PositionOps.scale_pos(node, n) - - def scale_pos(node, scalar): - return (node[0] * scalar, node[1] * scalar) - - def round_pos(node): - return (int(np.round(node[0])), int(np.round(node[1]))) - - def ceil_pos(node): - return (int(np.ceil(node[0])), int(np.ceil(node[1]))) - - def bound_pos(node, min_value, max_value): - return (max(min_value, min(max_value, node[0])), max(min_value, min(max_value, node[1]))) - - def rotate_pos(node, rot_in_degree): - alpha = rot_in_degree / 180.0 * np.pi - x0 = node[0] - y0 = node[1] - x1 = x0 * np.cos(alpha) - y0 * np.sin(alpha) - y1 = x0 * np.sin(alpha) + y0 * np.cos(alpha) - return (x1, y1) - - - - - def do_generate_city_locations(width, height, intern_city_size, intern_max_number_of_station_tracks): X = int(np.floor(max(1, height - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size)) @@ -107,9 +104,11 @@ def realistic_rail_generator(num_cities=5, for i in range(len(generate_city_locations)): # station main orientation (horizontal or vertical rot_angle = np.random.choice(allowed_rotation_angles) - add_pos_val = PositionOps.scale_pos(PositionOps.rotate_pos((1, 0), rot_angle), (max(1, (intern_city_size - 3) / 2))) + add_pos_val = PositionOps.scale_pos(PositionOps.rotate_pos((1, 0), rot_angle), + (max(1, (intern_city_size - 3) / 2))) generate_city_locations[i][0] = PositionOps.add_pos(generate_city_locations[i][1], add_pos_val) - add_pos_val = PositionOps.scale_pos(PositionOps.rotate_pos((1, 0), 180 + rot_angle), (max(1, (intern_city_size - 3) / 2))) + add_pos_val = PositionOps.scale_pos(PositionOps.rotate_pos((1, 0), 180 + rot_angle), + (max(1, (intern_city_size - 3) / 2))) generate_city_locations[i][1] = PositionOps.add_pos(generate_city_locations[i][1], add_pos_val) return generate_city_locations @@ -131,10 +130,13 @@ def realistic_rail_generator(num_cities=5, org_start_node = generate_city_locations[city_loop][0] org_end_node = generate_city_locations[city_loop][1] - ortho_trans = PositionOps.make_orthogonal_pos(PositionOps.normalize_pos(PositionOps.subtract_pos(org_start_node, org_end_node))) + ortho_trans = PositionOps.make_orthogonal_pos( + PositionOps.normalize_pos(PositionOps.subtract_pos(org_start_node, org_end_node))) s = (ct - number_of_connecting_tracks / 2.0) - start_node = PositionOps.ceil_pos(PositionOps.add_pos(org_start_node, PositionOps.scale_pos(ortho_trans, s))) - end_node = PositionOps.ceil_pos(PositionOps.add_pos(org_end_node, PositionOps.scale_pos(ortho_trans, s))) + start_node = PositionOps.ceil_pos( + PositionOps.add_pos(org_start_node, PositionOps.scale_pos(ortho_trans, s))) + end_node = PositionOps.ceil_pos( + PositionOps.add_pos(org_end_node, PositionOps.scale_pos(ortho_trans, s))) connection = connect_from_nodes(rail_trans, rail_array, start_node, end_node) if len(connection) > 0: @@ -154,7 +156,7 @@ def realistic_rail_generator(num_cities=5, for city_loop in range(len(switch_slots)): data = switch_slots[city_loop] - data_idx = np.random.choice(np.arange(len(data)),len(data),False) + data_idx = np.random.choice(np.arange(len(data)), len(data), False) for i in range(len(data) - 1): start_node = data[data_idx[i]] end_node = data[data_idx[i + 1]] @@ -178,7 +180,8 @@ def realistic_rail_generator(num_cities=5, b = [] for yLoop in x: for xLoop in x: - v = PositionOps.get_norm_pos(PositionOps.subtract_pos(start_nodes_added[xLoop][0], end_nodes_added[yLoop][0])) + v = PositionOps.get_norm_pos( + PositionOps.subtract_pos(start_nodes_added[xLoop][0], end_nodes_added[yLoop][0])) if v > 0: v = np.inf a[yLoop].append(v) @@ -246,7 +249,6 @@ def realistic_rail_generator(num_cities=5, if print_out_info: print("intern_max_number_of_station_tracks:", intern_max_number_of_station_tracks) - intern_nbr_of_switches_per_station_track = nbr_of_switches_per_station_track if nbr_of_switches_per_station_track < 1: warnings.warn("min intern_nbr_of_switches_per_station_track requried to be > 2!") @@ -254,8 +256,6 @@ def realistic_rail_generator(num_cities=5, if print_out_info: print("intern_nbr_of_switches_per_station_track:", intern_nbr_of_switches_per_station_track) - - inter_max_number_of_connecting_tracks = max_number_of_connecting_tracks if max_number_of_connecting_tracks < 1: warnings.warn("min inter_max_number_of_connecting_tracks requried to be > 1!")