diff --git a/docs/03_tutorials.rst b/docs/03_tutorials.rst index e862221d8c405cc7000399e4ffd7f092bdc4bc22..c2de2cf174fdc756737938e6f0d239370a5a87d5 100644 --- a/docs/03_tutorials.rst +++ b/docs/03_tutorials.rst @@ -3,3 +3,4 @@ .. include:: tutorials/03_rail_and_schedule_generator.rst .. include:: tutorials/04_stochasticity.rst .. include:: tutorials/05_multispeed.rst +.. include:: tutorials/06_round_2_starter_help.rst diff --git a/docs/tutorials/06_round_2_starter_help.md b/docs/tutorials/06_round_2_starter_help.md new file mode 100644 index 0000000000000000000000000000000000000000..37b7fa6c8195fa1bca7fbf735669996351e34538 --- /dev/null +++ b/docs/tutorials/06_round_2_starter_help.md @@ -0,0 +1,32 @@ +# How to get started in Round 2 + +- [Environment Changes](#environment-changes) +- [Level generation](#level-generation) +- [Observations](#observations) +- [Predictions](#predictions) + +## Environment Changes +There have been some major changes in how agents are being handled in the environment in this Flatland update. +### Agents +Agents are no more permant entities in the environment. Now agents will be removed from the environment as soon as they finsish their task. To keep interactions with the environment as simple as possible we do not modify the dimensions of the observation vectors nor the number of agents. Agents that have finished do not require any special treatment from the controller. Any action provided to these agents is simply ignored, just like before. + +Start positions of agents are *not unique* anymore. This means that many agents can start from the same position on the railway grid. It is important to keep in mind that whatever agent moves first will block the rest of the agents from moving into the same cell. Thus, the controller can already decide the ordering of the agents from the first step. + +## Level Generation +The levels are now generated using the `sparse_rail_generator` and the `sparse_schedule_generator` +### Rail Generation +The rail generation is done in a sequence of steps: +1. A number of city centers are placed in a a grid of size `(height, width)` +2. Each city is connected to two neighbouring cities +3. Internal parallel tracks are generated in each city + + +### Schedule Generation +The `sparse_schedule_generator` produces tasks for the agents by selecting a starting city and a target city. The agent is then placed on an even track number on the starting city and faced such that a path exists to the target city. The task for the agent is to reach the target position as fast as possible. + +In the future we will update how these schedules are generated to allow for more complex tasks + +## Observations +Observations have been updated to reflect the novel features and behaviors of Flatland. Have a look at [observation](https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py) or the documentation for more details on the observations. + +## Predicitons \ No newline at end of file diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py new file mode 100644 index 0000000000000000000000000000000000000000..9124edd1cae6e7d1f2f3b909c721a8f4ffd35ca3 --- /dev/null +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -0,0 +1,612 @@ +import copy +import os +import time +import warnings + +import numpy as np + +from flatland.core.grid.grid4_utils import mirror +from flatland.core.grid.rail_env_grid import RailEnvTransitions +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.grid4_generators_utils import connect_from_nodes, connect_nodes, connect_rail_in_grid_map +from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct +from flatland.envs.schedule_generators import sparse_schedule_generator +from flatland.utils.rendertools import RenderTool + + +class Vec2dOperations: + def subtract_pos(nodeA, nodeB): + """ + vector operation : nodeA - nodeB + + :param nodeA: tuple with coordinate (x,y) or 2d vector + :param nodeB: tuple with coordinate (x,y) or 2d vector + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + return (nodeA[0] - nodeB[0], nodeA[1] - nodeB[1]) + + def add_pos(nodeA, nodeB): + """ + vector operation : nodeA + nodeB + + :param nodeA: tuple with coordinate (x,y) or 2d vector + :param nodeB: tuple with coordinate (x,y) or 2d vector + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + return (nodeA[0] + nodeB[0], nodeA[1] + nodeB[1]) + + def make_orthogonal_pos(node): + """ + vector operation : rotates the 2D vector +90° + + :param node: tuple with coordinate (x,y) or 2d vector + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + return (node[1], -node[0]) + + def get_norm_pos(node): + """ + calculates the euclidean norm of the 2d vector + + :param node: tuple with coordinate (x,y) or 2d vector + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + return np.sqrt(node[0] * node[0] + node[1] * node[1]) + + def normalize_pos(node): + """ + normalize the 2d vector = v/|v| + + :param node: tuple with coordinate (x,y) or 2d vector + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + n = Vec2dOperations.get_norm_pos(node) + if n > 0.0: + n = 1 / n + return Vec2dOperations.scale_pos(node, n) + + def scale_pos(node, scalar): + """ + scales the 2d vector = node * scale + + :param node: tuple with coordinate (x,y) or 2d vector + :param scale: scalar to scale + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + return (node[0] * scalar, node[1] * scalar) + + def round_pos(node): + """ + rounds the x and y coordinate and convert them to an integer values + + :param node: tuple with coordinate (x,y) or 2d vector + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + return (int(np.round(node[0])), int(np.round(node[1]))) + + def ceil_pos(node): + """ + ceiling the x and y coordinate and convert them to an integer values + + :param node: tuple with coordinate (x,y) or 2d vector + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + return (int(np.ceil(node[0])), int(np.ceil(node[1]))) + + def bound_pos(node, min_value, max_value): + """ + force the values x and y to be between min_value and max_value + + :param node: tuple with coordinate (x,y) or 2d vector + :param min_value: scalar value + :param max_value: scalar value + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + return (max(min_value, min(max_value, node[0])), max(min_value, min(max_value, node[1]))) + + def rotate_pos(node, rot_in_degree): + """ + rotate the 2d vector with given angle in degree + + :param node: tuple with coordinate (x,y) or 2d vector + :param rot_in_degree: angle in degree + :return: + ------- + tuple with coordinate (x,y) or 2d vector + """ + alpha = rot_in_degree / 180.0 * np.pi + x0 = node[0] + y0 = node[1] + x1 = x0 * np.cos(alpha) - y0 * np.sin(alpha) + y1 = x0 * np.sin(alpha) + y0 * np.cos(alpha) + return (x1, y1) + + +class Grid_Map_Op: + def min_max_cut(min_v, max_v, v): + return max(min_v, min(max_v, v)) + + def add_rail(width, height, grid_map, pt_from, pt_via, pt_to, bAddRemove=True): + gRCTrans = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) # NESW in RC + + lrcStroke = [[Grid_Map_Op.min_max_cut(0, height - 1, pt_from[0]), + Grid_Map_Op.min_max_cut(0, width - 1, pt_from[1])], + [Grid_Map_Op.min_max_cut(0, height - 1, pt_via[0]), + Grid_Map_Op.min_max_cut(0, width - 1, pt_via[1])], + [Grid_Map_Op.min_max_cut(0, height - 1, pt_to[0]), + Grid_Map_Op.min_max_cut(0, width - 1, pt_to[1])]] + + rc3Cells = np.array(lrcStroke[:3]) # the 3 cells + rcMiddle = rc3Cells[1] # the middle cell which we will update + bDeadend = np.all(lrcStroke[0] == lrcStroke[2]) # deadend means cell 0 == cell 2 + + # get the 2 row, col deltas between the 3 cells, eg [[-1,0],[0,1]] = North, East + rc2Trans = np.diff(rc3Cells, axis=0) + + # get the direction index for the 2 transitions + liTrans = [] + for rcTrans in rc2Trans: + # gRCTrans - rcTrans gives an array of vector differences between our rcTrans + # and the 4 directions stored in gRCTrans. + # Where the vector difference is zero, we have a match... + # np.all detects where the whole row,col vector is zero. + # argwhere gives the index of the zero vector, ie the direction index + iTrans = np.argwhere(np.all(gRCTrans - rcTrans == 0, axis=1)) + if len(iTrans) > 0: + iTrans = iTrans[0][0] + liTrans.append(iTrans) + + # check that we have two transitions + if len(liTrans) == 2: + # Set the transition + # Set the transition + # If this transition spans 3 cells, it is not a deadend, so remove any deadends. + # The user will need to resolve any conflicts. + grid_map.set_transition((*rcMiddle, liTrans[0]), + liTrans[1], + bAddRemove, + remove_deadends=not bDeadend) + + # Also set the reverse transition + # use the reversed outbound transition for inbound + # and the reversed inbound transition for outbound + grid_map.set_transition((*rcMiddle, mirror(liTrans[1])), + mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend) + + +def realistic_rail_generator(max_num_cities=5, + city_size=10, + allowed_rotation_angles=[0, 90], + max_number_of_station_tracks=4, + nbr_of_switches_per_station_track=2, + connect_max_nbr_of_shortes_city=4, + do_random_connect_stations=False, + seed=0, + print_out_info=True) -> RailGenerator: + """ + This is a level generator which generates a realistic rail configurations + + :param max_num_cities: Number of city node + :param city_size: Length of city measure in cells + :param allowed_rotation_angles: Rotate the city (around center) + :param max_number_of_station_tracks: max number of tracks per station + :param nbr_of_switches_per_station_track: number of switches per track (max) + :param connect_max_nbr_of_shortes_city: max number of connecting track between stations + :param do_random_connect_stations : if false connect the stations along the grid (top,left -> down,right), else rand + :param seed: Random Seed + :print_out_info : print debug info + :return: + ------- + numpy.ndarray of type numpy.uint16 + The matrix with the correct 16-bit bitmaps for each cell. + """ + + def do_generate_city_locations(width, height, intern_city_size, intern_max_number_of_station_tracks): + + X = int(np.floor(max(1, height - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size)) + Y = int(np.floor(max(1, width - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size)) + + max_num_cities = min(max_num_cities, X * Y) + + cities_at = np.random.choice(X * Y, max_num_cities, False) + cities_at = np.sort(cities_at) + if print_out_info: + print("max nbr of cities with given configuration is:", max_num_cities) + + x = np.floor(cities_at / Y) + y = cities_at - x * Y + xs = (x * intern_city_size + intern_max_number_of_station_tracks) + intern_city_size / 2 + ys = (y * intern_city_size + intern_max_number_of_station_tracks) + intern_city_size / 2 + + generate_city_locations = [[(int(xs[i]), int(ys[i])), (int(xs[i]), int(ys[i]))] for i in range(len(xs))] + return generate_city_locations, max_num_cities + + def do_orient_cities(generate_city_locations, intern_city_size, allowed_rotation_angles): + for i in range(len(generate_city_locations)): + # station main orientation (horizontal or vertical + rot_angle = np.random.choice(allowed_rotation_angles) + add_pos_val = Vec2dOperations.scale_pos(Vec2dOperations.rotate_pos((1, 0), rot_angle), + (max(1, (intern_city_size - 3) / 2))) + generate_city_locations[i][0] = Vec2dOperations.add_pos(generate_city_locations[i][1], add_pos_val) + add_pos_val = Vec2dOperations.scale_pos(Vec2dOperations.rotate_pos((1, 0), 180 + rot_angle), + (max(1, (intern_city_size - 3) / 2))) + generate_city_locations[i][1] = Vec2dOperations.add_pos(generate_city_locations[i][1], add_pos_val) + return generate_city_locations + + def create_stations_from_city_locations(rail_trans, rail_array, generate_city_locations, + intern_max_number_of_station_tracks): + nodes_added = [] + start_nodes_added = [[] for i in range(len(generate_city_locations))] + end_nodes_added = [[] for i in range(len(generate_city_locations))] + station_slots = [[] for i in range(len(generate_city_locations))] + station_tracks = [[[] for j in range(intern_max_number_of_station_tracks)] for i in range(len( + generate_city_locations))] + + station_slots_cnt = 0 + + for city_loop in range(len(generate_city_locations)): + # Connect train station to the correct node + number_of_connecting_tracks = np.random.choice(max(0, intern_max_number_of_station_tracks)) + 1 + for ct in range(number_of_connecting_tracks): + org_start_node = generate_city_locations[city_loop][0] + org_end_node = generate_city_locations[city_loop][1] + + ortho_trans = Vec2dOperations.make_orthogonal_pos( + Vec2dOperations.normalize_pos(Vec2dOperations.subtract_pos(org_start_node, org_end_node))) + s = (ct - number_of_connecting_tracks / 2.0) + start_node = Vec2dOperations.ceil_pos( + Vec2dOperations.add_pos(org_start_node, Vec2dOperations.scale_pos(ortho_trans, s))) + end_node = Vec2dOperations.ceil_pos( + Vec2dOperations.add_pos(org_end_node, Vec2dOperations.scale_pos(ortho_trans, s))) + + connection = connect_from_nodes(rail_trans, rail_array, start_node, end_node) + if len(connection) > 0: + nodes_added.append(start_node) + nodes_added.append(end_node) + + start_nodes_added[city_loop].append(start_node) + end_nodes_added[city_loop].append(end_node) + + # place in the center of path a station slot + station_slots[city_loop].append(connection[int(np.floor(len(connection) / 2))]) + station_slots_cnt += 1 + + station_tracks[city_loop][ct] = connection + + if print_out_info: + print("max nbr of station slots with given configuration is:", station_slots_cnt) + + return nodes_added, station_slots, start_nodes_added, end_nodes_added, station_tracks + + def create_switches_at_stations(width, height, grid_map, station_tracks, nodes_added, + intern_nbr_of_switches_per_station_track): + # generate switch based on switch slot list and connect them + for city_loop in range(len(station_tracks)): + datas = station_tracks[city_loop] + for data_loop in range(len(datas) - 1): + data = datas[data_loop] + data1 = datas[data_loop + 1] + if len(data) > 2 and len(data1) > 2: + for i in np.random.choice(min(len(data1), len(data)) - 2, + intern_nbr_of_switches_per_station_track): + Grid_Map_Op.add_rail(width, height, grid_map, data[i + 1], data1[i + 1], data1[i + 2], True) + nodes_added.append(data[i + 1]) + nodes_added.append(data1[i + 1]) + nodes_added.append(data1[i + 2]) + + return nodes_added + + def calc_nbr_of_graphs(graph): + for i in range(len(graph)): + for j in range(len(graph)): + a = graph[i] + b = graph[j] + connected = False + if a[0] == b[0] or a[1] == b[0]: + connected = True + if a[0] == b[1] or a[1] == b[1]: + connected = True + + if connected: + a = [graph[i][0], graph[i][1], graph[i][2]] + b = [graph[j][0], graph[j][1], graph[j][2]] + graph[i] = (graph[i][0], graph[i][1], min(np.min(a), np.min(b))) + graph[j] = (graph[j][0], graph[j][1], min(np.min(a), np.min(b))) + else: + a = [graph[i][0], graph[i][1], graph[i][2]] + graph[i] = (graph[i][0], graph[i][1], np.min(a)) + b = [graph[j][0], graph[j][1], graph[j][2]] + graph[j] = (graph[j][0], graph[j][1], np.min(b)) + + graph_ids = [] + for i in range(len(graph)): + graph_ids.append(graph[i][2]) + if print_out_info: + print("************* NBR of graphs:", len(np.unique(graph_ids))) + return graph, np.unique(graph_ids).astype(int) + + def connect_sub_graphs(rail_trans, rail_array, org_s_nodes, org_e_nodes, city_edges, nodes_added): + _, graphids = calc_nbr_of_graphs(city_edges) + if len(graphids) > 0: + for i in range(len(graphids) - 1): + connection = [] + cnt = 0 + while len(connection) == 0 and cnt < 100: + s_nodes = copy.deepcopy(org_s_nodes) + e_nodes = copy.deepcopy(org_e_nodes) + start_nodes = s_nodes[graphids[i]] + end_nodes = e_nodes[graphids[i + 1]] + start_node = start_nodes[np.random.choice(len(start_nodes))] + end_node = end_nodes[np.random.choice(len(end_nodes))] + rail_array[start_node] = 0 + rail_array[end_node] = 0 + connection = connect_rail_in_grid_map(rail_array, start_node, end_node, rail_trans) + if len(connection) > 0: + nodes_added.append(start_node) + nodes_added.append(end_node) + cnt += 1 + + def connect_stations(rail_trans, rail_array, org_s_nodes, org_e_nodes, nodes_added, + inter_connect_max_nbr_of_shortes_city): + + city_edges = [] + + s_nodes = copy.deepcopy(org_s_nodes) + e_nodes = copy.deepcopy(org_e_nodes) + + for k in range(inter_connect_max_nbr_of_shortes_city): + for city_loop in range(len(s_nodes)): + sns = s_nodes[city_loop] + for start_node in sns: + min_distance = np.inf + end_node = None + for city_loop_find_shortest in range(len(e_nodes)): + if city_loop_find_shortest == city_loop: + continue + ens = e_nodes[city_loop_find_shortest] + for en in ens: + d = Vec2dOperations.get_norm_pos(Vec2dOperations.subtract_pos(en, start_node)) + if d < min_distance: + min_distance = d + end_node = en + cl = city_loop_find_shortest + + if end_node is not None: + tmp_trans_sn = rail_array[start_node] + tmp_trans_en = rail_array[end_node] + rail_array[start_node] = 0 + rail_array[end_node] = 0 + connection = connect_rail_in_grid_map(rail_array, start_node, end_node, rail_trans) + if len(connection) > 0: + s_nodes[city_loop].remove(start_node) + e_nodes[cl].remove(end_node) + a = (city_loop, cl, np.inf) + if city_loop > cl: + a = (cl, city_loop, np.inf) + if not (a in city_edges): + city_edges.append(a) + nodes_added.append(start_node) + nodes_added.append(end_node) + else: + rail_array[start_node] = tmp_trans_sn + rail_array[end_node] = tmp_trans_en + + connect_sub_graphs(rail_trans, rail_array, org_s_nodes, org_e_nodes, city_edges, nodes_added) + + def connect_random_stations(rail_trans, rail_array, start_nodes_added, end_nodes_added, nodes_added, + inter_connect_max_nbr_of_shortes_city): + x = np.arange(len(start_nodes_added)) + random_city_idx = np.random.choice(x, len(x), False) + + # cyclic connection + random_city_idx = np.append(random_city_idx, random_city_idx[0]) + + for city_loop in range(len(random_city_idx) - 1): + idx_a = random_city_idx[city_loop + 1] + idx_b = random_city_idx[city_loop] + s_nodes = start_nodes_added[idx_a] + e_nodes = end_nodes_added[idx_b] + + max_input_output = max(len(s_nodes), len(e_nodes)) + max_input_output = min(inter_connect_max_nbr_of_shortes_city, max_input_output) + + if do_random_connect_stations: + idx_s_nodes = np.random.choice(np.arange(len(s_nodes)), len(s_nodes), False) + idx_e_nodes = np.random.choice(np.arange(len(e_nodes)), len(e_nodes), False) + else: + idx_s_nodes = np.arange(len(s_nodes)) + idx_e_nodes = np.arange(len(e_nodes)) + + if len(idx_s_nodes) < max_input_output: + idx_s_nodes = np.append(idx_s_nodes, np.random.choice(np.arange(len(s_nodes)), max_input_output - len( + idx_s_nodes))) + if len(idx_e_nodes) < max_input_output: + idx_e_nodes = np.append(idx_e_nodes, + np.random.choice(np.arange(len(idx_e_nodes)), max_input_output - len( + idx_e_nodes))) + + if len(idx_s_nodes) > inter_connect_max_nbr_of_shortes_city: + idx_s_nodes = np.random.choice(idx_s_nodes, inter_connect_max_nbr_of_shortes_city, False) + if len(idx_e_nodes) > inter_connect_max_nbr_of_shortes_city: + idx_e_nodes = np.random.choice(idx_e_nodes, inter_connect_max_nbr_of_shortes_city, False) + + for i in range(max_input_output): + start_node = s_nodes[idx_s_nodes[i]] + end_node = e_nodes[idx_e_nodes[i]] + new_trans = rail_array[start_node] = 0 + new_trans = rail_array[end_node] = 0 + connection = connect_nodes(rail_trans, rail_array, start_node, end_node) + if len(connection) > 0: + nodes_added.append(start_node) + nodes_added.append(end_node) + + def generator(width, height, num_agents, num_resets=0) -> RailGeneratorProduct: + rail_trans = RailEnvTransitions() + grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) + rail_array = grid_map.grid + rail_array.fill(0) + np.random.seed(seed + num_resets) + + intern_city_size = city_size + if city_size < 3: + warnings.warn("min city_size requried to be > 3!") + intern_city_size = 3 + if print_out_info: + print("intern_city_size:", intern_city_size) + + intern_max_number_of_station_tracks = max_number_of_station_tracks + if max_number_of_station_tracks < 1: + warnings.warn("min max_number_of_station_tracks requried to be > 1!") + intern_max_number_of_station_tracks = 1 + if print_out_info: + print("intern_max_number_of_station_tracks:", intern_max_number_of_station_tracks) + + intern_nbr_of_switches_per_station_track = nbr_of_switches_per_station_track + if nbr_of_switches_per_station_track < 1: + warnings.warn("min intern_nbr_of_switches_per_station_track requried to be > 2!") + intern_nbr_of_switches_per_station_track = 2 + if print_out_info: + print("intern_nbr_of_switches_per_station_track:", intern_nbr_of_switches_per_station_track) + + inter_connect_max_nbr_of_shortes_city = connect_max_nbr_of_shortes_city + if connect_max_nbr_of_shortes_city < 1: + warnings.warn("min inter_connect_max_nbr_of_shortes_city requried to be > 1!") + inter_connect_max_nbr_of_shortes_city = 1 + if print_out_info: + print("inter_connect_max_nbr_of_shortes_city:", inter_connect_max_nbr_of_shortes_city) + + agent_start_targets_nodes = [] + + # ---------------------------------------------------------------------------------- + # generate city locations + generate_city_locations, max_num_cities = do_generate_city_locations(width, height, intern_city_size, + intern_max_number_of_station_tracks) + + # ---------------------------------------------------------------------------------- + # apply orientation to cities (horizontal, vertical) + generate_city_locations = do_orient_cities(generate_city_locations, intern_city_size, allowed_rotation_angles) + + # ---------------------------------------------------------------------------------- + # generate city topology + nodes_added, train_stations, s_nodes, e_nodes, station_tracks = \ + create_stations_from_city_locations(rail_trans, rail_array, + generate_city_locations, + intern_max_number_of_station_tracks) + # build switches + create_switches_at_stations(width, height, grid_map, station_tracks, nodes_added, + intern_nbr_of_switches_per_station_track) + + # ---------------------------------------------------------------------------------- + # connect stations + if True: + if do_random_connect_stations: + connect_random_stations(rail_trans, rail_array, s_nodes, e_nodes, nodes_added, + inter_connect_max_nbr_of_shortes_city) + else: + connect_stations(rail_trans, rail_array, s_nodes, e_nodes, nodes_added, + inter_connect_max_nbr_of_shortes_city) + + # ---------------------------------------------------------------------------------- + # fix all transition at starting / ending points (mostly add a dead end, if missing) + for i in range(len(nodes_added)): + grid_map.fix_transitions(nodes_added[i]) + + # ---------------------------------------------------------------------------------- + # Slot availability in node + node_available_start = [] + node_available_target = [] + for node_idx in range(max_num_cities): + node_available_start.append(len(train_stations[node_idx])) + node_available_target.append(len(train_stations[node_idx])) + + # Assign agents to slots + for agent_idx in range(num_agents): + avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0] + avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0] + if len(avail_target_nodes) == 0: + num_agents -= 1 + continue + start_node = np.random.choice(avail_start_nodes) + target_node = np.random.choice(avail_target_nodes) + tries = 0 + found_agent_pair = True + while target_node == start_node: + target_node = np.random.choice(avail_target_nodes) + tries += 1 + # Test again with new start node if no pair is found (This code needs to be improved) + if (tries + 1) % 10 == 0: + start_node = np.random.choice(avail_start_nodes) + if tries > 100: + warnings.warn("Could not set trainstations, removing agent!") + found_agent_pair = False + break + if found_agent_pair: + node_available_start[start_node] -= 1 + node_available_target[target_node] -= 1 + agent_start_targets_nodes.append((start_node, target_node)) + else: + num_agents -= 1 + + return grid_map, {'agents_hints': { + 'num_agents': num_agents, + 'agent_start_targets_nodes': agent_start_targets_nodes, + 'train_stations': train_stations + }} + + return generator + + +for itrials in range(100): + print(itrials, "generate new city") + np.random.seed(int(time.time())) + env = RailEnv(width=40 + np.random.choice(100), + height=40 + np.random.choice(100), + rail_generator=realistic_rail_generator(max_num_cities=2 + np.random.choice(10), + city_size=10 + np.random.choice(10), + allowed_rotation_angles=[-90, -45, 0, 45, 90], + max_number_of_station_tracks=np.random.choice(4) + 4, + nbr_of_switches_per_station_track=np.random.choice(4) + 2, + connect_max_nbr_of_shortes_city=np.random.choice(4) + 2, + do_random_connect_stations=np.random.choice(1) == 0, + # Number of cities in map + seed=int(time.time()), # Random seed + print_out_info=False + ), + schedule_generator=sparse_schedule_generator(), + number_of_agents=1000, + obs_builder_object=GlobalObsForRailEnv()) + + # reset to initialize agents_static + env_renderer = RenderTool(env, gl="PILSVG", screen_width=1400, screen_height=1000) + cnt = 0 + while cnt < 10: + env_renderer.render_env(show=True, show_observations=False, show_predictions=False) + cnt += 1 + + env_renderer.gl.save_image( + os.path.join( + "./../render_output/", + "flatland_frame_{:04d}_{:04d}.png".format(itrials, 0) + )) + + env_renderer.close_window() diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index da8cacef2a3e9972ef4b90094c59f448ca07bc37..25e68d3e907f1c897145108d2cffae25d7fc4e91 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -1,3 +1,5 @@ +import time + import numpy as np from flatland.envs.observations import TreeObsForRailEnv @@ -5,7 +7,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator -from flatland.utils.rendertools import RenderTool +from flatland.utils.rendertools import RenderTool, AgentRenderVariant np.random.seed(1) @@ -13,7 +15,7 @@ np.random.seed(1) # Training on simple small tasks is the best way to get familiar with the environment # Use a the malfunction generator to break agents from time to time -stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents +stochastic_data = {'prop_malfunction': 0.3, # Percentage of defective agents 'malfunction_rate': 30, # Rate of malfunction occurence 'min_duration': 3, # Minimal duration of malfunction 'max_duration': 20 # Max duration of malfunction @@ -30,22 +32,28 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=25, # Number of cities in map (where train stations are) - num_intersections=10, # Number of intersections (no start / target) - num_trainstations=50, # Number of possible start/targets on map - min_node_dist=3, # Minimal distance of nodes - node_radius=4, # Proximity of stations to city center - num_neighb=4, # Number of connections to other cities/intersections - seed=15, # Random seed - grid_mode=True, - enhance_intersection=False + rail_generator=sparse_rail_generator(max_num_cities=10, + # Number of cities in map (where train stations are) + seed=1, # Random seed + grid_mode=False, + max_rails_between_cities=2, + max_rails_in_city=4, ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=20, stochastic_data=stochastic_data, # Malfunction data generator - obs_builder_object=TreeObservation) + obs_builder_object=TreeObservation, + remove_agents_at_target=True + ) + +# RailEnv.DEPOT_POSITION = lambda agent, agent_handle : (agent_handle % env.height,0) -env_renderer = RenderTool(env, gl="PILSVG", ) + +env_renderer = RenderTool(env, gl="PILSVG", + agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX, + show_debug=True, + screen_height=1000, + screen_width=1000) # Import your own Agent or use RLlib to train agents on Flatland @@ -61,7 +69,7 @@ class RandomAgent: :param state: input is the observation of the agent :return: returns an action """ - return np.random.choice(np.arange(self.action_size)) + return 2 # np.random.choice(np.arange(self.action_size)) def step(self, memories): """ @@ -90,8 +98,11 @@ action_dict = dict() print("Start episode...") # Reset environment and get initial observations for all agents +start_reset = time.time() obs = env.reset() - +end_reset = time.time() +print(end_reset - start_reset) +print(env.get_num_agents(), ) # Reset the rendering sytem env_renderer.reset() diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index 61ca23840d6c59e75d75d5a15fa9fb8118f45b95..61b83c3b9ec2d5f63203897aa52c87d27c4460bd 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -13,7 +13,7 @@ np.random.seed(1) env = RailEnv(width=7, height=7, - rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), schedule_generator=complex_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2)) diff --git a/examples/simple_example_city_railway_generator.py b/examples/simple_example_city_railway_generator.py deleted file mode 100644 index 4182bc27128469c72ba8e2246645d629070eaf20..0000000000000000000000000000000000000000 --- a/examples/simple_example_city_railway_generator.py +++ /dev/null @@ -1,62 +0,0 @@ -import os - -import numpy as np - -from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d -from flatland.envs.observations import GlobalObsForRailEnv -from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_generators_city_generator import city_generator -from flatland.envs.schedule_generators import city_schedule_generator -from flatland.utils.rendertools import RenderTool, AgentRenderVariant - -OUTPUT_DIR = "./../render_output/" -if not os.path.exists(OUTPUT_DIR): - os.mkdir(OUTPUT_DIR) - -for itrials in np.arange(1, 15, 1): - print(itrials, "generate new city") - - # init seed - np.random.seed(itrials) - - # select distance function used in a-star path finding - dist_fun = Vec2d.get_manhattan_distance - dfsel = (itrials - 1) % 3 - if dfsel == 1: - dist_fun = Vec2d.get_euclidean_distance - elif dfsel == 2: - dist_fun = Vec2d.get_chebyshev_distance - - # create RailEnv and use the city_generator to create a map - env = RailEnv(width=40 + np.random.choice(100), - height=40 + np.random.choice(100), - rail_generator=city_generator(num_cities=5 + np.random.choice(10), - city_size=10 + np.random.choice(5), - allowed_rotation_angles=np.arange(0, 360, 6), - max_number_of_station_tracks=4 + np.random.choice(4), - nbr_of_switches_per_station_track=2 + np.random.choice(2), - connect_max_nbr_of_shortes_city=2 + np.random.choice(4), - do_random_connect_stations=itrials % 2 == 0, - a_star_distance_function=dist_fun, - seed=itrials, - print_out_info=False - ), - schedule_generator=city_schedule_generator(), - number_of_agents=10000, - obs_builder_object=GlobalObsForRailEnv()) - - # reset to initialize agents_static - env_renderer = RenderTool(env, gl="PILSVG", screen_width=1400, screen_height=1000, - agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX) - - env_renderer.render_env(show=True, show_observations=False, show_predictions=False) - - # store rendered file into render_output if the path exists - env_renderer.gl.save_image( - os.path.join( - OUTPUT_DIR, - "flatland_frame_{:04d}.png".format(itrials) - )) - - # close the renderer / window - env_renderer.close_window() diff --git a/examples/training_example.py b/examples/training_example.py index df93479f5a5ee05abfcb1a98b07ef052bffc2bd4..78c0299d4cee8ae588bcf8e9e7559ff1c8364c26 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -16,7 +16,7 @@ TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictor LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) env = RailEnv(width=20, height=20, - rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=2, min_dist=8, max_dist=99999, seed=0), schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObservation, number_of_agents=3) diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py index 3a75aa81193d2355f71a05d8825bc64da4547f6f..03917a05f0940f6185d1cece439d34b11c2bcc10 100644 --- a/flatland/core/grid/grid4_astar.py +++ b/flatland/core/grid/grid4_astar.py @@ -1,5 +1,3 @@ -import numpy as np - from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance from flatland.core.grid.grid_utils import IntVector2DArray from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d @@ -37,9 +35,24 @@ class AStarNode: self.f = other.f -def a_star(grid_map: GridTransitionMap, - start: IntVector2D, end: IntVector2D, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: +def a_star(grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, + respect_transition_validity=True, forbidden_cells: IntVector2DArray = None) -> IntVector2DArray: + """ + + :param grid_map: Grid Map where the path is found in + :param start: Start positions as (row,column) + :param end: End position as (row,column) + :param a_star_distance_function: Define the distance function to use as heuristc: + -get_euclidean_distance + -get_manhattan_distance + -get_chebyshev_distance + :param respect_transition_validity: Whether or not a-star respect allowed transitions on the grid map. + - True: Respects the validity of transition. This generates valid paths, of no path if it cannot be found + - False: This always finds a path, but the path might be illegal and thus needs to be fixed afterwards + :param forbidden_cells: List of cells where the path cannot pass through. Used to avoid certain areas of Grid map + :return: IF a path is found a ordered list of al cells in path is returned + """ """ Returns a list of tuples as a path from the given start to end. If no path is found, returns path to closest point to end. @@ -93,11 +106,18 @@ def a_star(grid_map: GridTransitionMap, continue # validate positions - if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos, end_node.pos): + # + if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos, + end_node.pos) and respect_transition_validity: continue - # create new node new_node = AStarNode(node_pos, current_node) + + # Skip paths through forbidden regions if they are provided + if forbidden_cells is not None: + if node_pos in forbidden_cells and new_node != start_node and new_node != end_node: + continue + children.append(new_node) # loop through children diff --git a/flatland/core/grid/grid4_utils.py b/flatland/core/grid/grid4_utils.py index 75cef7b4d3aea783140a5c08c3498a0bc321fb62..1475589ef62863fcb98a2d238b4b4c5dbe078b3c 100644 --- a/flatland/core/grid/grid4_utils.py +++ b/flatland/core/grid/grid4_utils.py @@ -1,3 +1,5 @@ +import numpy as np + from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid_utils import IntVector2D @@ -34,3 +36,25 @@ def get_new_position(position, movement): return (position[0] + 1, position[1]) elif movement == Grid4TransitionsEnum.WEST: return (position[0], position[1] - 1) + + +def direction_to_point(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum: + """ + Returns the closest direction orientation of position 2 relative to position 1 + :param pos1: position we are interested in + :param pos2: position we want to know it is facing + :return: direction NESW as int N:0 E:1 S:2 W:3 + """ + diff_vec = np.array((pos1[0] - pos2[0], pos1[1] - pos2[1])) + axis = np.argmax(np.power(diff_vec, 2)) + direction = np.sign(diff_vec[axis]) + if axis == 0: + if direction > 0: + return Grid4TransitionsEnum.NORTH + else: + return Grid4TransitionsEnum.SOUTH + else: + if direction > 0: + return Grid4TransitionsEnum.WEST + else: + return Grid4TransitionsEnum.EAST diff --git a/flatland/core/grid/grid_utils.py b/flatland/core/grid/grid_utils.py index 6004c53567f46feb30e283f0b9dbd946ee8f1375..56858185b37c01822c623232964c285c1ffe52a7 100644 --- a/flatland/core/grid/grid_utils.py +++ b/flatland/core/grid/grid_utils.py @@ -2,6 +2,8 @@ from typing import Tuple, Callable, List, Type import numpy as np +from flatland.core.grid.grid4 import Grid4TransitionsEnum + Vector2D: Type = Tuple[float, float] IntVector2D: Type = Tuple[int, int] @@ -289,5 +291,8 @@ def coordinate_to_position(depth, coords): return position -def distance_on_rail(pos1, pos2): - return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1]) +def distance_on_rail(pos1, pos2, metric="Euclidean"): + if metric == "Euclidean": + return np.sqrt(np.power(pos1[0] - pos2[0], 2) + np.power(pos1[1] - pos2[1], 2)) + if metric == "Manhattan": + return np.abs(pos1[0] - pos2[0]) + np.abs(pos1[1] - pos2[1]) diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 9db7f3c7775a01824a849d8dc126fbbb3955d212..2a3db9df346faea5a8734e3cb4a78194133cb319 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -14,6 +14,7 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transitions import Transitions from flatland.utils.ordered_set import OrderedSet + # TODO are these general classes or for grid4 only? class TransitionMap: """ @@ -422,6 +423,28 @@ class GridTransitionMap(TransitionMap): continue else: return False + # If the cell is empty but has incoming connections we return false + if binTrans < 1: + connected = 0 + + for iDirOut in np.arange(4): + gdRC = gDir2dRC[iDirOut] # row,col increment + gPos2 = grcPos + gdRC # next cell in that direction + + # Check the adjacent cell is within bounds + # if not, then ignore it for the count of incoming connections + if np.any(gPos2 < 0): + continue + if np.any(gPos2 >= grcMax): + continue + + # Get the transitions out of gPos2, using iDirOut as the inbound direction + # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid + + for orientation in range(4): + connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut)) + if connected > 0: + return False return True @@ -477,7 +500,7 @@ class GridTransitionMap(TransitionMap): return True - def fix_transitions(self, rcPos: IntVector2DArray): + def fix_transitions(self, rcPos: IntVector2DArray, direction: IntVector2D = -1): """ Fixes broken transitions """ @@ -492,9 +515,8 @@ class GridTransitionMap(TransitionMap): simple_switch_west_south = transitions.rotate_transition(cells[2], 270) symmetrical = cells[6] double_slip = cells[5] - three_way_transitions = [simple_switch_east_south, simple_switch_west_south, symmetrical] + three_way_transitions = [simple_switch_east_south, simple_switch_west_south] # loop over available outbound directions (indices) for rcPos - self.set_transitions(rcPos, 0) incoming_connections = np.zeros(4) for iDirOut in np.arange(4): @@ -517,21 +539,38 @@ class GridTransitionMap(TransitionMap): incoming_connections[iDirOut] = 1 number_of_incoming = np.sum(incoming_connections) - # Only one incoming direction --> Straight line + # Only one incoming direction --> Straight line set deadend if number_of_incoming == 1: - for direction in range(4): - if incoming_connections[direction] > 0: - self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1) + if self.get_full_transitions(*rcPos) == 0: + self.set_transitions(rcPos, 0) + else: + self.set_transitions(rcPos, 0) + + for direction in range(4): + if incoming_connections[direction] > 0: + self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1) # Connect all incoming connections if number_of_incoming == 2: + self.set_transitions(rcPos, 0) + connect_directions = np.argwhere(incoming_connections > 0) self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1) self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1) # Find feasible connection for three entries if number_of_incoming == 3: - transition = np.random.choice(three_way_transitions, 1) + self.set_transitions(rcPos, 0) hole = np.argwhere(incoming_connections < 1)[0][0] + if direction > 0: + switch_type_idx = (direction - hole + 3) % 4 + if switch_type_idx == 2: + transition = simple_switch_west_south + if switch_type_idx == 0: + transition = simple_switch_east_south + else: + transition = np.random.choice(three_way_transitions, 1) + else: + transition = np.random.choice(three_way_transitions, 1) transition = transitions.rotate_transition(transition, int(hole * 90)) self.set_transitions((rcPos[0], rcPos[1]), transition) diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index fce3ffdf320a3c38d7f0551151ffdc8debe6ab5d..0b866e3348728814efc721fc13a15544bde0878d 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -5,29 +5,41 @@ Generator functions are functions that take width, height and num_resets as argu a GridTransitionMap object. """ +import numpy as np + +from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_astar import a_star -from flatland.core.grid.grid4_utils import get_direction, mirror +from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance, IntVector2DArray from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions -def connect_basic_operation( - rail_trans: RailEnvTransitions, - grid_map: GridTransitionMap, - start: IntVector2D, - end: IntVector2D, - flip_start_node_trans=False, - flip_end_node_trans=False, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: +def connect_rail_in_grid_map(grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, + rail_trans: RailEnvTransitions, + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, + flip_start_node_trans: bool = False, flip_end_node_trans: bool = False, + respect_transition_validity: bool = True, forbidden_cells: IntVector2DArray = None) -> IntVector2DArray: """ - Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and + Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and returns the path created as a list of positions. + :param rail_trans: basic rail transition object + :param grid_map: grid map + :param start: start position of rail + :param end: end position of rail + :param flip_start_node_trans: make valid start position by adding dead-end, empty start if False + :param flip_end_node_trans: make valid end position by adding dead-end, empty end if False + :param respect_transition_validity: Only draw rail maps if legal rail elements can be use, False, draw line without respecting rail transitions. + :param a_star_distance_function: Define what distance function a-star should use + :param forbidden_cells: cells to avoid when drawing rail. Rail cannot go through this list of cells + :return: List of cells in the path """ - # in the worst case we will need to do a A* search, so we might as well set that up - path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function) + + path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, respect_transition_validity, + forbidden_cells) if len(path) < 2: return [] + current_dir = get_direction(path[0], path[1]) end_pos = path[-1] for index in range(len(path) - 1): @@ -72,26 +84,44 @@ def connect_basic_operation( return path -def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - start: IntVector2D, end: IntVector2D, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: - return connect_basic_operation(rail_trans, grid_map, start, end, True, True, a_star_distance_function) +def connect_straight_line_in_grid_map(grid_map: GridTransitionMap, start: IntVector2D, + end: IntVector2D, rail_trans: RailEnvTransitions) -> IntVector2DArray: + """ + Generates a straight rail line from start cell to end cell. + Diagonal lines are not allowed + :param rail_trans: + :param grid_map: + :param start: Cell coordinates for start of line + :param end: Cell coordinates for end of line + :return: A list of all cells in the path + """ + + if not (start[0] == end[0] or start[1] == end[1]): + print("No straight line possible!") + return [] + direction = direction_to_point(start, end) -def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - start: IntVector2D, end: IntVector2D, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: - return connect_basic_operation(rail_trans, grid_map, start, end, False, False, a_star_distance_function) + if direction is Grid4TransitionsEnum.NORTH or direction is Grid4TransitionsEnum.SOUTH: + start_row = min(start[0], end[0]) + end_row = max(start[0], end[0]) + 1 + rows = np.arange(start_row, end_row) + length = np.abs(end[0] - start[0]) + 1 + cols = np.repeat(start[1], length) + else: # Grid4TransitionsEnum.EAST or Grid4TransitionsEnum.WEST + start_col = min(start[1], end[1]) + end_col = max(start[1], end[1]) + 1 + cols = np.arange(start_col, end_col) + length = np.abs(end[1] - start[1]) + 1 + rows = np.repeat(start[0], length) -def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - start: IntVector2D, end: IntVector2D, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance - ) -> IntVector2DArray: - return connect_basic_operation(rail_trans, grid_map, start, end, False, True, a_star_distance_function) + path = list(zip(rows, cols)) + for cell in path: + transition = grid_map.grid[cell] + transition = rail_trans.set_transition(transition, direction, direction, 1) + transition = rail_trans.set_transition(transition, mirror(direction), mirror(direction), 1) + grid_map.grid[cell] = transition -def connect_to_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - start: IntVector2D, end: IntVector2D, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: - return connect_basic_operation(rail_trans, grid_map, start, end, True, False, a_star_distance_function) + return path diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 5cc6a8c11c2b7d5045d66a820f312dc0fd61a492..dc640b14d54a7a730a023bf931167f712dd16de0 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -67,8 +67,7 @@ class TreeObsForRailEnv(ObservationBuilder): self.predicted_dir = {} self.predictions = self.predictor.get() if self.predictions: - # TODO hacky hacky: `range(len(self.predictions[0]))` does not seem safe!! - for t in range(len(self.predictions[0])): + for t in range(self.predictor.max_depth+1): pos_list = [] dir_list = [] for a in handles: @@ -190,15 +189,15 @@ class TreeObsForRailEnv(ObservationBuilder): agent = self.env.agents[handle] # TODO: handle being treated as index if agent.status == RailAgentStatus.READY_TO_DEPART: - _agent_initial_position = agent.initial_position + agent_virtual_position = agent.initial_position elif agent.status == RailAgentStatus.ACTIVE: - _agent_initial_position = agent.position + agent_virtual_position = agent.position elif agent.status == RailAgentStatus.DONE: - _agent_initial_position = agent.target + agent_virtual_position = agent.target else: return None - possible_transitions = self.env.rail.get_transitions(*_agent_initial_position, agent.direction) + possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction) num_transitions = np.count_nonzero(possible_transitions) # Here information about the agent itself is stored @@ -208,7 +207,7 @@ class TreeObsForRailEnv(ObservationBuilder): dist_other_agent_encountered=0, dist_potential_conflict=0, dist_unusable_switch=0, dist_to_next_branch=0, dist_min_to_target=distance_map[ - (handle, *_agent_initial_position, + (handle, *agent_virtual_position, agent.direction)], num_agents_same_direction=0, num_agents_opposite_direction=0, num_agents_malfunctioning=agent.malfunction_data['malfunction'], @@ -229,7 +228,7 @@ class TreeObsForRailEnv(ObservationBuilder): for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]): if possible_transitions[branch_direction]: - new_cell = get_new_position(_agent_initial_position, branch_direction) + new_cell = get_new_position(agent_virtual_position, branch_direction) branch_observation, branch_visited = \ self._explore_branch(handle, new_cell, branch_direction, 1, 1) @@ -294,18 +293,24 @@ class TreeObsForRailEnv(ObservationBuilder): if self.location_has_agent_direction[position] == direction: # Cummulate the number of agents on branch with same direction - other_agent_same_direction += 1 + other_agent_same_direction += self.location_has_agent_direction.get((position, direction), 0) # Check fractional speed of agents current_fractional_speed = self.location_has_agent_speed[position] if current_fractional_speed < min_fractional_speed: min_fractional_speed = current_fractional_speed - if self.location_has_agent_direction[position] != direction: - # Cummulate the number of agents on branch with other direction - other_agent_opposite_direction += 1 + # Other direction agents + # TODO: Test that this behavior is as expected + other_agent_opposite_direction += \ + self.location_has_agent[position] - self.location_has_agent_direction.get((position, direction), + 0) - # Check number of possible transitions for agent and total number of transitions in cell (type) + else: + # If no agent in the same direction was found all agents in that position are other direction + other_agent_opposite_direction += self.location_has_agent[position] + + # Check number of possible transitions for agent and total number of transitions in cell (type) cell_transitions = self.env.rail.get_transitions(*position, direction) transition_bit = bin(self.env.rail.get_full_transitions(*position)) total_transitions = transition_bit.count("1") @@ -528,15 +533,15 @@ class GlobalObsForRailEnv(ObservationBuilder): - transition map array with dimensions (env.height, env.width, 16),\ assuming 16 bits encoding of transitions. - - A 3D array (map_height, map_width, 4) with + - obs_agents_state: A 3D array (map_height, map_width, 5) with - first channel containing the agents position and direction - - second channel containing the other agents positions and diretion + - second channel containing the other agents positions and direction - third channel containing agent/other agent malfunctions - fourth channel containing agent/other agent fractional speeds - fifth channel containing number of other agents ready to depart - - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\ - target and the positions of the other agents targets. + - obs_targets: Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\ + target and the positions of the other agents targets (flag only, no counter!). """ def __init__(self): @@ -557,34 +562,44 @@ class GlobalObsForRailEnv(ObservationBuilder): agent = self.env.agents[handle] if agent.status == RailAgentStatus.READY_TO_DEPART: - _agent_initial_position = agent.initial_position + agent_virtual_position = agent.initial_position elif agent.status == RailAgentStatus.ACTIVE: - _agent_initial_position = agent.position + agent_virtual_position = agent.position elif agent.status == RailAgentStatus.DONE: - _agent_initial_position = agent.target + agent_virtual_position = agent.target else: return None obs_targets = np.zeros((self.env.height, self.env.width, 2)) obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1 - obs_agents_state[_agent_initial_position][0] = agent.direction + # TODO can we do this more elegantly? + for r in range(self.env.height): + for c in range(self.env.width): + obs_agents_state[(r, c)][4] = 0 + + obs_agents_state[agent_virtual_position][0] = agent.direction obs_targets[agent.target][0] = 1 for i in range(len(self.env.agents)): other_agent: EnvAgent = self.env.agents[i] - # ignore other_agent if it is not in the grid - if other_agent.position is None: - continue - if i != handle: - obs_agents_state[other_agent.position][1] = other_agent.direction - obs_targets[other_agent.target][1] = 1 - if other_agent.status == RailAgentStatus.READY_TO_DEPART: - obs_agents_state[other_agent.initial_position] += 1 - obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction'] - obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed'] + # ignore other agents not in the grid any more + if other_agent.status == RailAgentStatus.DONE_REMOVED: + continue + obs_targets[other_agent.target][1] = 1 + + # second to fourth channel only if in the grid + if other_agent.position is not None: + # second channel only for other agents + if i != handle: + obs_agents_state[other_agent.position][1] = other_agent.direction + obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction'] + obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed'] + # fifth channel: all ready to depart on this position + if other_agent.status == RailAgentStatus.READY_TO_DEPART: + obs_agents_state[other_agent.initial_position][4] += 1 return self.rail_obs, obs_agents_state, obs_targets diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 29b6947c28fa054cd1f4c44a204c740e4f536181..aec38ee3926605be6b82779baf62ea6613dbb464 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -52,10 +52,10 @@ class DummyPredictorForRailEnv(PredictionBuilder): # TODO make this generic continue action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT] - _agent_initial_position = agent.position - _agent_initial_direction = agent.direction + agent_virtual_position = agent.position + agent_virtual_direction = agent.direction prediction = np.zeros(shape=(self.max_depth + 1, 5)) - prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] + prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0] for index in range(1, self.max_depth + 1): action_done = False # if we're at the target, stop moving... @@ -77,8 +77,8 @@ class DummyPredictorForRailEnv(PredictionBuilder): if not action_done: raise Exception("Cannot move further. Something is wrong") prediction_dict[agent.handle] = prediction - agent.position = _agent_initial_position - agent.direction = _agent_initial_direction + agent.position = agent_virtual_position + agent.direction = agent_virtual_direction return prediction_dict @@ -128,20 +128,20 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): for agent in agents: if agent.status == RailAgentStatus.READY_TO_DEPART: - _agent_initial_position = agent.initial_position + agent_virtual_position = agent.initial_position elif agent.status == RailAgentStatus.ACTIVE: - _agent_initial_position = agent.position + agent_virtual_position = agent.position elif agent.status == RailAgentStatus.DONE: - _agent_initial_position = agent.target + agent_virtual_position = agent.target else: prediction_dict[agent.handle] = None continue - _agent_initial_direction = agent.direction + agent_virtual_direction = agent.direction agent_speed = agent.speed_data["speed"] times_per_cell = int(np.reciprocal(agent_speed)) prediction = np.zeros(shape=(self.max_depth + 1, 5)) - prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] + prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0] shortest_path = shortest_paths[agent.handle] @@ -149,8 +149,8 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): if shortest_path: shortest_path = shortest_path[1:] - new_direction = _agent_initial_direction - new_position = _agent_initial_position + new_direction = agent_virtual_direction + new_position = agent_virtual_position visited = OrderedSet() for index in range(1, self.max_depth + 1): # if we're at the target or not moving, stop moving until max_depth is reached diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 1b5ca23f36e3571aa75d9fd69431c119054ab05c..ff741b625091831787a44bd8e6f77aee237bc101 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -116,7 +116,8 @@ class RailEnv(Environment): number_of_agents=1, obs_builder_object: ObservationBuilder = TreeObsForRailEnv(max_depth=2), max_episode_steps=None, - stochastic_data=None + stochastic_data=None, + remove_agents_at_target=False ): """ Environment init. @@ -147,6 +148,9 @@ class RailEnv(Environment): ObservationBuilder-derived object that takes builds observation vectors for each agent. max_episode_steps : int or None + remove_agents_at_target : bool + If remove_agents_at_target is set to true then the agents will be removed by placing to + RailEnv.DEPOT_POSITION when the agent has reach it's target position. """ super().__init__() @@ -157,6 +161,8 @@ class RailEnv(Environment): self.width = width self.height = height + self.remove_agents_at_target = remove_agents_at_target + self.rewards = [0] * number_of_agents self.done = False self.obs_builder = obs_builder_object @@ -253,6 +259,7 @@ class RailEnv(Environment): rc_pos = (r, c) check = self.rail.cell_neighbours_valid(rc_pos, True) if not check: + print(self.rail.grid[rc_pos]) warnings.warn("Invalid grid at {} -> {}".format(rc_pos, check)) # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/172 # hacky: we must re-compute the distance map and not use the initial distance_map loaded from file by @@ -375,10 +382,9 @@ class RailEnv(Environment): self._step_agent(i_agent, action_dict_.get(i_agent)) # Check for end of episode + set global reward to all rewards! - if np.all([np.array_equal(agent.position, agent.target) for agent in self.agents]): + if np.all([agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED] for agent in self.agents]): self.dones["__all__"] = True self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())} - if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps): self.dones["__all__"] = True for i in range(self.get_num_agents()): @@ -387,7 +393,8 @@ class RailEnv(Environment): info_dict = { 'action_required': { - i: (agent.status == RailAgentStatus.ACTIVE and agent.speed_data['position_fraction'] == 0.0) + i: (agent.status == RailAgentStatus.READY_TO_DEPART or ( + agent.status == RailAgentStatus.ACTIVE and agent.speed_data['position_fraction'] == 0.0)) for i, agent in enumerate(self.agents)}, 'malfunction': { i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents()) @@ -412,7 +419,7 @@ class RailEnv(Environment): """ agent = self.agents[i_agent] - if agent.status == RailAgentStatus.DONE: # this agent has already completed... + if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]: # this agent has already completed... return # agent gets active by a MOVE_* action and if c @@ -519,6 +526,10 @@ class RailEnv(Environment): agent.status = RailAgentStatus.DONE self.dones[i_agent] = True agent.moving = False + + if self.remove_agents_at_target: + agent.position = None + agent.status = RailAgentStatus.DONE_REMOVED else: self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] else: diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 0bce90aac8faf1df4758755a8e2d82d6e20dd133..9c4835376b24afe2121e844d436984be93b3f63f 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -1,15 +1,18 @@ """Rail generators (infrastructure manager, "Infrastrukturbetreiber").""" import warnings -from typing import Callable, Tuple, Optional, Dict, List, Any +from typing import Callable, Tuple, Optional, Dict, List import msgpack import numpy as np -from flatland.core.grid.grid4_utils import get_direction, mirror -from flatland.core.grid.grid_utils import distance_on_rail +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point +from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d +from flatland.core.grid.grid_utils import distance_on_rail, IntVector2DArray, IntVector2D, \ + Vec2dOperations from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes +from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]] RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] @@ -123,7 +126,9 @@ def complex_rail_generator(nr_start_goal=1, # we might as well give up at this point break - new_path = connect_rail(rail_trans, grid_map, start, goal) + new_path = connect_rail_in_grid_map(grid_map, start, goal, rail_trans, Vec2d.get_chebyshev_distance, + flip_start_node_trans=True, flip_end_node_trans=True, + respect_transition_validity=True, forbidden_cells=None) if len(new_path) >= 2: nr_created += 1 start_goal.append([start, goal]) @@ -148,7 +153,10 @@ def complex_rail_generator(nr_start_goal=1, break if not all_ok: break - new_path = connect_rail(rail_trans, grid_map, start, goal) + new_path = connect_rail_in_grid_map(grid_map, start, goal, rail_trans, Vec2d.get_chebyshev_distance, + flip_start_node_trans=True, flip_end_node_trans=True, + respect_transition_validity=True, forbidden_cells=None) + if len(new_path) >= 2: nr_created += 1 @@ -532,307 +540,385 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener return generator -def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2, min_node_dist=20, node_radius=2, - num_neighb=3, grid_mode=False, enhance_intersection=False, seed=0) -> RailGenerator: +def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_rails_between_cities: int = 4, + max_rails_in_city: int = 4, seed: int = 0) -> RailGenerator: """ - This is a level generator which generates complex sparse rail configurations - - :param num_cities: Number of city node (can hold trainstations) - :type num_cities: int - :param num_intersections: Number of intersection that city nodes can connect to - :param num_trainstations: Total number of trainstations in env - :param min_node_dist: Minimal distance between nodes - :param node_radius: Proximity of trainstations to center of city node - :param num_neighb: Number of neighbouring nodes each node connects to - :param grid_mode: True -> NOdes evenly distirbuted in env, False-> Random distribution of nodes - :param enhance_intersection: True -> Extra rail elements added at intersections - :param seed: Random Seed - :return: numpy.ndarray of type numpy.uint16 -- The matrix with the correct 16-bit bitmaps for each cell. + Generates railway networks with cities and inner city rails + :param max_num_cities: Number of city centers in the map + :param grid_mode: arrange cities in a grid or randomly + :param max_rails_between_cities: Maximum number of connecting rails going out from a city + :param max_rails_in_city: maximum number of internal rails + :param seed: Random seed to initiate rail + :return: generator """ - def generator(width, height, num_agents, num_resets=0) -> RailGeneratorProduct: - - if num_agents > num_trainstations: - num_agents = num_trainstations - warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") + def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct: + np.random.seed(seed + num_resets) rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) - rail_array = grid_map.grid - rail_array.fill(0) - np.random.seed(seed + num_resets) + cell_vector_field = np.zeros(shape=(height, width), dtype=int) - 1 + city_radius = int(np.ceil((max_rails_in_city + 2) / 2.0)) + 1 - # Generate a set of nodes for the sparse network - # Try to connect cities to nodes first - city_positions = [] - intersection_positions = [] + min_nr_rails_in_city = 3 + rails_in_city = min_nr_rails_in_city if max_rails_in_city < min_nr_rails_in_city else max_rails_in_city + rails_between_cities = rails_in_city if max_rails_between_cities > rails_in_city else max_rails_between_cities - # Evenly distribute cities and intersections - node_positions: List[Any] = None - nb_nodes = num_cities + num_intersections + # Evenly distribute cities if grid_mode: - nodes_ratio = height / width - nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio))) - nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row)) - x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) - y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) - city_idx = np.random.choice(np.arange(nb_nodes), num_cities, False) - - node_positions = _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, - nb_nodes, - nodes_per_row, x_positions, - y_positions) - - - + city_positions, city_cells = _generate_evenly_distr_city_positions(max_num_cities, city_radius, width, height) else: - - node_positions = _generate_node_positions_not_grid_mode(city_positions, height, - intersection_positions, - nb_nodes, width) - - # reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode - nb_nodes = len(node_positions) - _num_cities = len(city_positions) - _num_intersections = len(intersection_positions) - - # Chose node connection - # Set up list of available nodes to connect to - available_nodes_full = np.arange(nb_nodes) - available_cities = np.arange(_num_cities) - available_intersections = np.arange(_num_cities, nb_nodes) - - # Start at some node - current_node = np.random.randint(len(available_nodes_full)) - node_stack = [current_node] - allowed_connections = num_neighb - first_node = True - while len(node_stack) > 0: - current_node = node_stack[0] - delete_idx = np.where(available_nodes_full == current_node) - available_nodes_full = np.delete(available_nodes_full, delete_idx, 0) - - # Priority city to intersection connections - if current_node < _num_cities and len(available_intersections) > 0: - available_nodes = available_intersections - delete_idx = np.where(available_cities == current_node) - available_cities = np.delete(available_cities, delete_idx, 0) - - # Priority intersection to city connections - elif current_node >= _num_cities and len(available_cities) > 0: - available_nodes = available_cities - delete_idx = np.where(available_intersections == current_node) - available_intersections = np.delete(available_intersections, delete_idx, 0) - - # If no options possible connect to whatever node is still available - else: - available_nodes = available_nodes_full - - # Sort available neighbors according to their distance. - node_dist = [] - for av_node in available_nodes: - node_dist.append(distance_on_rail(node_positions[current_node], node_positions[av_node])) - available_nodes = available_nodes[np.argsort(node_dist)] - - # Set number of neighboring nodes - if len(available_nodes) >= allowed_connections: - connected_neighb_idx = available_nodes[:allowed_connections] - else: - connected_neighb_idx = available_nodes - - # Less connections for subsequent nodes - if first_node: - allowed_connections -= 1 - first_node = False - - # Connect to the neighboring nodes - for neighb in connected_neighb_idx: - if neighb not in node_stack: - node_stack.append(neighb) - connect_nodes(rail_trans, grid_map, node_positions[current_node], node_positions[neighb]) - node_stack.pop(0) - - # Place train stations close to the node - # We currently place them uniformly distributed among all cities - built_num_trainstation = 0 - train_stations = [[] for i in range(_num_cities)] - - if _num_cities > 1: - - for station in range(num_trainstations): - spot_found = True - trainstation_node = int(station / num_trainstations * _num_cities) - - station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), - 0, - height - 1) - station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius), - 0, - width - 1) - tries = 0 - while (station_x, station_y) in train_stations[trainstation_node] \ - or (station_x, station_y) == node_positions[trainstation_node] \ - or rail_array[(station_x, station_y)] != 0: # noqa: E125 - - station_x = np.clip( - node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), - 0, - height - 1) - station_y = np.clip( - node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius), - 0, - width - 1) - tries += 1 - if tries > 100: - warnings.warn("Could not set trainstations, please change initial parameters!!!!") - spot_found = False - break - - if spot_found: - train_stations[trainstation_node].append((station_x, station_y)) - - # Connect train station to the correct node - connection = connect_from_nodes(rail_trans, grid_map, node_positions[trainstation_node], - (station_x, station_y)) - # Check if connection was made - if len(connection) == 0: - if len(train_stations[trainstation_node]) > 0: - train_stations[trainstation_node].pop(-1) - else: - built_num_trainstation += 1 - - # Adjust the number of agents if you could not build enough trainstations - if num_agents > built_num_trainstation: - num_agents = built_num_trainstation - warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") - - # Place passing lanes at intersections - # We currently place them uniformly distirbuted among all cities - if enhance_intersection: - - for intersection in range(_num_intersections): - intersect_x_1 = np.clip(intersection_positions[intersection][0] + np.random.randint(1, 3), - 1, - height - 2) - intersect_y_1 = np.clip(intersection_positions[intersection][1] + np.random.randint(-3, 3), - 2, - width - 2) - intersect_x_2 = np.clip( - intersection_positions[intersection][0] + np.random.randint(-3, -1), - 1, - height - 2) - intersect_y_2 = np.clip( - intersection_positions[intersection][1] + np.random.randint(-3, 3), - 1, - width - 2) - - # Connect train station to the correct node - connect_nodes(rail_trans, grid_map, (intersect_x_1, intersect_y_1), - (intersect_x_2, intersect_y_2)) - connect_nodes(rail_trans, grid_map, intersection_positions[intersection], - (intersect_x_1, intersect_y_1)) - connect_nodes(rail_trans, grid_map, intersection_positions[intersection], - (intersect_x_2, intersect_y_2)) - grid_map.fix_transitions((intersect_x_1, intersect_y_1)) - grid_map.fix_transitions((intersect_x_2, intersect_y_2)) - - # Fix all nodes with illegal transition maps - for current_node in node_positions: - grid_map.fix_transitions(current_node) - - # Generate start and target node directory for all agents. - # Assure that start and target are not in the same node - agent_start_targets_nodes = [] - - # Slot availability in node - node_available_start = [] - node_available_target = [] - for node_idx in range(_num_cities): - node_available_start.append(len(train_stations[node_idx])) - node_available_target.append(len(train_stations[node_idx])) - - # Assign agents to slots - for agent_idx in range(num_agents): - avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0] - avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0] - start_node = np.random.choice(avail_start_nodes) - target_node = np.random.choice(avail_target_nodes) - tries = 0 - found_agent_pair = True - while target_node == start_node: - target_node = np.random.choice(avail_target_nodes) - tries += 1 - # Test again with new start node if no pair is found (This code needs to be improved) - if (tries + 1) % 10 == 0: - start_node = np.random.choice(avail_start_nodes) - if tries > 100: - warnings.warn("Could not set trainstations, removing agent!") - found_agent_pair = False - break - if found_agent_pair: - node_available_start[start_node] -= 1 - node_available_target[target_node] -= 1 - agent_start_targets_nodes.append((start_node, target_node)) - else: - num_agents -= 1 + city_positions, city_cells = _generate_random_city_positions(max_num_cities, city_radius, width, height) + + # reduce num_cities if less were generated in random mode + num_cities = len(city_positions) + + # Set up connection points for all cities + inner_connection_points, outer_connection_points, connection_info, city_orientations = _generate_city_connection_points( + city_positions, city_radius, rails_between_cities, + rails_in_city) + + # Connect the cities through the connection points + inter_city_lines = _connect_cities(city_positions, outer_connection_points, city_cells, + rail_trans, grid_map) + # Build inner cities + through_tracks, free_rails = _build_inner_cities(city_positions, inner_connection_points, + outer_connection_points, + rail_trans, + grid_map) + # Populate cities + train_stations = _set_trainstation_positions(city_positions, city_radius, free_rails) + + # Fix all transition elements + _fix_transitions(city_cells, inter_city_lines, grid_map) + + # Generate start target pairs + agent_start_targets_cities = _generate_start_target_pairs(num_agents, num_cities, train_stations, + city_orientations) return grid_map, {'agents_hints': { 'num_agents': num_agents, - 'agent_start_targets_nodes': agent_start_targets_nodes, - 'train_stations': train_stations + 'agent_start_targets_cities': agent_start_targets_cities, + 'train_stations': train_stations, + 'city_orientations': city_orientations }} - def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes, - width): - - node_positions = [] - for node_idx in range(nb_nodes): - to_close = True + def _generate_random_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (IntVector2DArray, IntVector2DArray): + city_positions: IntVector2DArray = [] + city_cells: IntVector2DArray = [] + for city_idx in range(num_cities): + too_close = True tries = 0 - while to_close: - x_tmp = node_radius + np.random.randint(height - node_radius) - y_tmp = node_radius + np.random.randint(width - node_radius) - to_close = False - + while too_close: + row = city_radius + 1 + np.random.randint(height - 2 * (city_radius + 1)) + col = city_radius + 1 + np.random.randint(width - 2 * (city_radius + 1)) + too_close = False # Check distance to cities - for node_pos in city_positions: - if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: - to_close = True - - # Check distance to intersections - for node_pos in intersection_positions: - if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: - to_close = True - - if not to_close: - node_positions.append((x_tmp, y_tmp)) - if node_idx < num_cities: - city_positions.append((x_tmp, y_tmp)) - else: - intersection_positions.append((x_tmp, y_tmp)) + for city_pos in city_positions: + if _are_cities_overlapping((row, col), city_pos, 2 * (city_radius + 1) + 1): + too_close = True + + if not too_close: + city_positions.append((row, col)) + city_cells.extend(_get_cells_in_city(city_positions[-1], city_radius)) + tries += 1 - if tries > 100: + if tries > 200: warnings.warn( - "Could not only set {} nodes after {} tries, although {} of nodes required to be generated!".format( - len(node_positions), - tries, nb_nodes)) + "Could not only set {} cities after {} tries, although {} of cities required to be generated!".format( + len(city_positions), + tries, num_cities)) + break + return city_positions, city_cells + + def _generate_evenly_distr_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (IntVector2DArray, IntVector2DArray): + aspect_ratio = height / width + cities_per_row = int(np.ceil(np.sqrt(num_cities * aspect_ratio))) + cities_per_col = int(np.ceil(num_cities / cities_per_row)) + row_positions = np.linspace(city_radius + 1, height - city_radius - 2, cities_per_row, dtype=int) + col_positions = np.linspace(city_radius + 1, width - city_radius - 2, cities_per_col, dtype=int) + city_positions = [] + city_cells = [] + for city_idx in range(num_cities): + row = row_positions[city_idx % cities_per_row] + col = col_positions[city_idx // cities_per_row] + city_positions.append((row, col)) + city_cells.extend(_get_cells_in_city(city_positions[-1], city_radius)) + return city_positions, city_cells + + def _generate_city_connection_points(city_positions: IntVector2DArray, city_radius: int, rails_between_cities: int, + rails_in_city: int = 2) -> (List[List[List[IntVector2D]]], + List[List[List[IntVector2D]]], + List[np.ndarray], + List[Grid4TransitionsEnum]): + inner_connection_points: List[List[List[IntVector2D]]] = [] + outer_connection_points: List[List[List[IntVector2D]]] = [] + connection_info: List[np.ndarray] = [] + city_orientations: List[Grid4TransitionsEnum] = [] + for city_position in city_positions: + + # Chose the directions where close cities are situated + neighb_dist = [] + for neighbour_city in city_positions: + neighb_dist.append(Vec2dOperations.get_manhattan_distance(city_position, neighbour_city)) + closest_neighb_idx = argsort(neighb_dist) + + # Store the directions to these neighbours and orient city to face closest neighbour + connection_sides_idx = [] + idx = 1 + if grid_mode: + current_closest_direction = np.random.randint(4) + else: + current_closest_direction = direction_to_point(city_position, city_positions[closest_neighb_idx[idx]]) + connection_sides_idx.append(current_closest_direction) + connection_sides_idx.append((current_closest_direction + 2) % 4) + city_orientations.append(current_closest_direction) + # set the number of tracks within a city, at least 2 tracks per city + connections_per_direction = np.zeros(4, dtype=int) + nr_of_connection_points = np.random.randint(3, rails_in_city + 1) + for idx in connection_sides_idx: + connections_per_direction[idx] = nr_of_connection_points + connection_points_coordinates_inner: List[List[IntVector2D]] = [[] for i in range(4)] + connection_points_coordinates_outer: List[List[IntVector2D]] = [[] for i in range(4)] + number_of_out_rails = np.random.randint(1, min(rails_between_cities, nr_of_connection_points) + 1) + start_idx = int((nr_of_connection_points - number_of_out_rails) / 2) + for direction in range(4): + connection_slots = np.arange(connections_per_direction[direction]) - int( + connections_per_direction[direction] / 2) + for connection_idx in range(connections_per_direction[direction]): + if direction == 0: + tmp_coordinates = ( + city_position[0] - city_radius, city_position[1] + connection_slots[connection_idx]) + if direction == 1: + tmp_coordinates = ( + city_position[0] + connection_slots[connection_idx], city_position[1] + city_radius) + if direction == 2: + tmp_coordinates = ( + city_position[0] + city_radius, city_position[1] + connection_slots[connection_idx]) + if direction == 3: + tmp_coordinates = ( + city_position[0] + connection_slots[connection_idx], city_position[1] - city_radius) + connection_points_coordinates_inner[direction].append(tmp_coordinates) + if connection_idx in range(start_idx, start_idx + number_of_out_rails + 1): + connection_points_coordinates_outer[direction].append(tmp_coordinates) + + inner_connection_points.append(connection_points_coordinates_inner) + outer_connection_points.append(connection_points_coordinates_outer) + connection_info.append(connections_per_direction) + return inner_connection_points, outer_connection_points, connection_info, city_orientations + + def _connect_cities(city_positions: IntVector2DArray, connection_points: List[List[List[IntVector2D]]], city_cells: IntVector2DArray, + rail_trans: RailEnvTransitions, grid_map: GridTransitionMap) -> List[IntVector2DArray]: + """ + Function to connect the different cities through their connection points + :param city_positions: Positions of city centers + :param connection_points: Boarder connection points of cities + :param rail_trans: Transitions + :param grid_map: Grid map + :return: + """ + all_paths: List[IntVector2DArray] = [] + + grid4_directions = [Grid4TransitionsEnum.NORTH, Grid4TransitionsEnum.EAST, Grid4TransitionsEnum.SOUTH, + Grid4TransitionsEnum.WEST] + + for current_city_idx in np.arange(len(city_positions)): + closest_neighbours = _closest_neighbour_in_grid4_directions(current_city_idx, city_positions) + for out_direction in grid4_directions: + + neighbour_idx = get_closest_neighbour_for_direction(closest_neighbours, out_direction) + + for city_out_connection_point in connection_points[current_city_idx][out_direction]: + + min_connection_dist = np.inf + for direction in grid4_directions: + current_points = connection_points[neighbour_idx][direction] + for tmp_in_connection_point in current_points: + tmp_dist = Vec2dOperations.get_manhattan_distance(city_out_connection_point, + tmp_in_connection_point) + if tmp_dist < min_connection_dist: + min_connection_dist = tmp_dist + neighbour_connection_point = tmp_in_connection_point + + new_line = connect_rail_in_grid_map(grid_map, city_out_connection_point, neighbour_connection_point, + rail_trans, flip_start_node_trans=False, + flip_end_node_trans=False, respect_transition_validity=False, + forbidden_cells=city_cells) + all_paths.extend(new_line) + + return all_paths + + def get_closest_neighbour_for_direction(closest_neighbours, out_direction): + neighbour_idx = closest_neighbours[out_direction] + if neighbour_idx is not None: + return neighbour_idx + + neighbour_idx = closest_neighbours[(out_direction - 1) % 4] # counter-clockwise + if neighbour_idx is not None: + return neighbour_idx + + neighbour_idx = closest_neighbours[(out_direction + 1) % 4] # clockwise + if neighbour_idx is not None: + return neighbour_idx + + return closest_neighbours[(out_direction + 2) % 4] # clockwise + + def _build_inner_cities(city_positions: IntVector2DArray, inner_connection_points: List[List[List[IntVector2D]]], + outer_connection_points: List[List[List[IntVector2D]]], rail_trans: RailEnvTransitions, + grid_map: GridTransitionMap) -> (List[IntVector2DArray], List[List[List[IntVector2D]]]): + """ + Builds inner city tracks. This current version connects all incoming connections to all outgoing connections + :param city_positions: Positions of the cities + :param inner_connection_points: Points on city boarder that are used to generate inner city track + :param outer_connection_points: Points where the city is connected to neighboring cities + :param rail_trans: + :param grid_map: + :return: Returns the cells of the through path which cannot be occupied by trainstations + """ + through_path_cells: List[IntVector2DArray] = [[] for i in range(len(city_positions))] + free_rails: List[List[List[IntVector2D]]] = [[] for i in range(len(city_positions))] + for current_city in range(len(city_positions)): + all_outer_connection_points = [item for sublist in outer_connection_points[current_city] for item in + sublist] + # This part only works if we have keep same number of connection points for both directions + # Also only works with two connection direction at each city + for i in range(4): + if len(inner_connection_points[current_city][i]) > 0: + boarder = i break - node_positions = city_positions + intersection_positions - return node_positions - - def _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, nb_nodes, - nodes_per_row, x_positions, y_positions): - for node_idx in range(nb_nodes): + opposite_boarder = (boarder + 2) % 4 + boarder_one = inner_connection_points[current_city][boarder] + boarder_two = inner_connection_points[current_city][opposite_boarder] + + # Connect the ends of the tracks + connect_straight_line_in_grid_map(grid_map, boarder_one[0], boarder_one[-1], rail_trans) + connect_straight_line_in_grid_map(grid_map, boarder_two[0], boarder_two[-1], rail_trans) + + # Connect parallel tracks + for track_id in range(len(inner_connection_points[current_city][boarder])): + source = inner_connection_points[current_city][boarder][track_id] + target = inner_connection_points[current_city][opposite_boarder][track_id] + current_track = connect_straight_line_in_grid_map(grid_map, source, target, rail_trans) + if target in all_outer_connection_points and source in all_outer_connection_points and len(through_path_cells[current_city]) < 1: + through_path_cells[current_city].extend(current_track) + else: + free_rails[current_city].append(current_track) + return through_path_cells, free_rails + + def _set_trainstation_positions(city_positions: IntVector2DArray, city_radius: int, + free_rails: List[List[List[IntVector2D]]]) -> List[List[Tuple[IntVector2D, int]]]: + num_cities = len(city_positions) + train_stations = [[] for i in range(num_cities)] + for current_city in range(len(city_positions)): + for track_nbr in range(len(free_rails[current_city])): + possible_location = free_rails[current_city][track_nbr][city_radius] + train_stations[current_city].append((possible_location, track_nbr)) + return train_stations + + def _generate_start_target_pairs(num_agents: int, num_cities: int, + train_stations: List[List[Tuple[IntVector2D, int]]], + city_orientation: List[Grid4TransitionsEnum]) -> List[Tuple[int, int, + Grid4TransitionsEnum]]: + # Generate start and target city directory for all agents. + # Assure that start and target are not in the same city + agent_start_targets_cities = [] + + # Slot availability in city + city_available_start = [] + city_available_target = [] + for city_idx in range(num_cities): + city_available_start.append(len(train_stations[city_idx])) + city_available_target.append(len(train_stations[city_idx])) - x_tmp = x_positions[node_idx % nodes_per_row] - y_tmp = y_positions[node_idx // nodes_per_row] - if node_idx in city_idx: - city_positions.append((x_tmp, y_tmp)) - else: - intersection_positions.append((x_tmp, y_tmp)) - node_positions = city_positions + intersection_positions - return node_positions + # Assign agents to slots + for agent_idx in range(num_agents): + avail_start_cities = [idx for idx, val in enumerate(city_available_start) if val > 0] + avail_target_cities = [idx for idx, val in enumerate(city_available_target) if val > 0] + # Set probability to choose start and stop from trainstations + sum_start = sum(np.array(city_available_start)[avail_start_cities]) + sum_target = sum(np.array(city_available_target)[avail_target_cities]) + p_avail_start = [float(i) / sum_start for i in np.array(city_available_start)[avail_start_cities]] + + start_target_tuple = np.random.choice(avail_start_cities, p=p_avail_start, size=2, replace=False) + start_city = start_target_tuple[0] + target_city = start_target_tuple[1] + agent_start_targets_cities.append((start_city, target_city, city_orientation[start_city])) + return agent_start_targets_cities + + def _fix_transitions(city_cells: IntVector2DArray, inter_city_lines: List[IntVector2DArray], + grid_map: GridTransitionMap): + """ + Function to fix all transition elements in environment + """ + # Fix all cities with illegal transition maps + rails_to_fix = np.zeros(2 * grid_map.height * grid_map.width * 2, dtype='int') + rails_to_fix_cnt = 0 + cells_to_fix = city_cells + inter_city_lines + for cell in cells_to_fix: + cell_valid = grid_map.cell_neighbours_valid(cell, True) + if grid_map.grid[cell] == int('1000010000100001', 2): + grid_map.fix_transitions(cell) + if not cell_valid: + rails_to_fix[2 * rails_to_fix_cnt] = cell[0] + rails_to_fix[2 * rails_to_fix_cnt + 1] = cell[1] + rails_to_fix_cnt += 1 + + # Fix all other cells + for cell in range(rails_to_fix_cnt): + grid_map.fix_transitions((rails_to_fix[2 * cell], rails_to_fix[2 * cell + 1]), ) + + def _closest_neighbour_in_grid4_directions(current_city_idx: int, city_positions: IntVector2DArray) -> List[int]: + """ + Returns indices of closest neighbour in every direction NESW + :param current_city_idx: Index of city in city_positions list + :param city_positions: list of all points being considered + :return: list of index of closest neighbour in all directions + """ + city_distances = [] + closest_neighbour: List[int] = [None for i in range(4)] + + # compute distance to all other cities + for city_idx in range(len(city_positions)): + city_distances.append(Vec2dOperations.get_manhattan_distance(city_positions[current_city_idx], city_positions[city_idx])) + sorted_neighbours = np.argsort(city_distances) + + for neighbour in sorted_neighbours[1:]: # do not include city itself + direction_to_neighbour = direction_to_point(city_positions[current_city_idx], city_positions[neighbour]) + if closest_neighbour[direction_to_neighbour] is None: + closest_neighbour[direction_to_neighbour] = neighbour + + # early return once all 4 directions have a closest neighbour + if None not in closest_neighbour: + return closest_neighbour + + return closest_neighbour + + def argsort(seq): + # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python + return sorted(range(len(seq)), key=seq.__getitem__) + + def _get_cells_in_city(center: IntVector2D, radius: int) -> IntVector2DArray: + """ + + Parameters + ---------- + center center coordinates of city + radius radius of city (it is a square) + + Returns + ------- + flat list of all cell coordinates in the city + + """ + x_range = np.arange(center[0] - radius, center[0] + radius + 1) + y_range = np.arange(center[1] - radius, center[1] + radius + 1) + x_values = np.repeat(x_range, len(y_range)) + y_values = np.tile(y_range, len(x_range)) + return list(zip(x_values, y_values)) + + def _are_cities_overlapping(center_1, center_2, radius): + return np.abs(center_1[0] - center_2[0]) < radius and np.abs(center_1[1] - center_2[1]) < radius return generator diff --git a/flatland/envs/rail_generators_city_generator.py b/flatland/envs/rail_generators_city_generator.py deleted file mode 100644 index ecea9f902d509572afd9087a6e9b64bee144b3f2..0000000000000000000000000000000000000000 --- a/flatland/envs/rail_generators_city_generator.py +++ /dev/null @@ -1,499 +0,0 @@ -import copy -import warnings -from typing import Sequence, Optional - -import numpy as np - -from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2DDistance, IntVector2DArrayArray -from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d -from flatland.core.grid.rail_env_grid import RailEnvTransitions -from flatland.core.transition_map import GridTransitionMap -from flatland.envs.grid4_generators_utils import connect_from_nodes, connect_nodes, connect_rail -from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct - -FloatArrayType = Sequence[float] - - -def city_generator(num_cities: int = 5, - city_size: int = 10, - allowed_rotation_angles: Optional[Sequence[float]] = None, - max_number_of_station_tracks: int = 4, - nbr_of_switches_per_station_track: int = 2, - connect_max_nbr_of_shortes_city: int = 4, - do_random_connect_stations: bool = False, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, - seed: int = 0, - print_out_info: bool = True) -> RailGenerator: - """ - This is a level generator which generates a realistic rail configurations - - :param num_cities: Number of city node - :param city_size: Length of city measure in cells - :param allowed_rotation_angles: Rotate the city (around center) - :param max_number_of_station_tracks: max number of tracks per station - :param nbr_of_switches_per_station_track: number of switches per track (max) - :param connect_max_nbr_of_shortes_city: max number of connecting track between stations - :param do_random_connect_stations : if false connect the stations along the grid (top,left -> down,right), else rand - :param a_star_distance_function: Heuristic how the distance between two nodes get estimated in the "a-star" path - :param seed: Random Seed - :param print_out_info: print debug info if True - - :return: The matrix with the correct 16-bit bitmaps for each cell. - :rtype: numpy.ndarray of type numpy.uint16 - - """ - - def do_generate_city_locations(width: int, - height: int, - intern_city_size: int, - intern_max_number_of_station_tracks: int) -> (IntVector2DArray, int): - - X = int(np.floor(max(1, height - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size)) - Y = int(np.floor(max(1, width - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size)) - - max_num_cities = min(num_cities, X * Y) - - cities_at = np.random.choice(X * Y, max_num_cities, False) - cities_at = np.sort(cities_at) - if print_out_info: - print("max nbr of cities with given configuration is:", max_num_cities) - - x = np.floor(cities_at / Y) - y = cities_at - x * Y - xs = (x * intern_city_size + intern_max_number_of_station_tracks) + intern_city_size / 2 - ys = (y * intern_city_size + intern_max_number_of_station_tracks) + intern_city_size / 2 - - generate_city_locations = [[(int(xs[i]), int(ys[i])), (int(xs[i]), int(ys[i]))] for i in range(len(xs))] - return generate_city_locations, max_num_cities - - def do_orient_cities(generate_city_locations: IntVector2DArrayArray, intern_city_size: int, - rotation_angles_set: FloatArrayType): - for i in range(len(generate_city_locations)): - # station main orientation (horizontal or vertical - rot_angle = np.random.choice(rotation_angles_set) - add_pos_val = Vec2d.scale(Vec2d.rotate((1, 0), rot_angle), - int(max(1.0, (intern_city_size - 3) / 2))) - # noinspection PyTypeChecker - generate_city_locations[i][0] = Vec2d.add(generate_city_locations[i][1], add_pos_val) - add_pos_val = Vec2d.scale(Vec2d.rotate((1, 0), 180 + rot_angle), - int(max(1.0, (intern_city_size - 3) / 2))) - # noinspection PyTypeChecker - generate_city_locations[i][1] = Vec2d.add(generate_city_locations[i][1], add_pos_val) - return generate_city_locations - - # noinspection PyTypeChecker - def create_stations_from_city_locations(rail_trans: RailEnvTransitions, - grid_map: GridTransitionMap, - generate_city_locations: IntVector2DArrayArray, - intern_max_number_of_station_tracks: int) -> (IntVector2DArray, - IntVector2DArray, - IntVector2DArray, - IntVector2DArray, - IntVector2DArray): - - nodes_added = [] - start_nodes_added: IntVector2DArrayArray = [[] for _ in range(len(generate_city_locations))] - end_nodes_added: IntVector2DArrayArray = [[] for _ in range(len(generate_city_locations))] - station_slots = [[] for _ in range(len(generate_city_locations))] - station_tracks = [[[] for _ in range(intern_max_number_of_station_tracks)] for _ in range(len( - generate_city_locations))] - - station_slots_cnt = 0 - - for city_loop in range(len(generate_city_locations)): - # Connect train station to the correct node - number_of_connecting_tracks = np.random.choice(max(0, intern_max_number_of_station_tracks)) + 1 - track_id = 0 - for ct in range(number_of_connecting_tracks): - org_start_node = generate_city_locations[city_loop][0] - org_end_node = generate_city_locations[city_loop][1] - - ortho_trans = Vec2d.make_orthogonal( - Vec2d.normalize(Vec2d.subtract(org_start_node, org_end_node))) - s = (ct - number_of_connecting_tracks / 2.0) - start_node = Vec2d.ceil( - Vec2d.add(org_start_node, Vec2d.scale(ortho_trans, s))) - end_node = Vec2d.ceil( - Vec2d.add(org_end_node, Vec2d.scale(ortho_trans, s))) - - connection = connect_from_nodes(rail_trans, grid_map, start_node, end_node, a_star_distance_function) - if len(connection) > 0: - nodes_added.append(start_node) - nodes_added.append(end_node) - - start_nodes_added[city_loop].append(start_node) - end_nodes_added[city_loop].append(end_node) - - # place in the center of path a station slot - # station_slots[city_loop].append(connection[int(np.floor(len(connection) / 2))]) - for c_loop in range(len(connection)): - station_slots[city_loop].append(connection[c_loop]) - station_slots_cnt += len(connection) - - station_tracks[city_loop][track_id] = connection - track_id += 1 - else: - if print_out_info: - print("create_stations_from_city_locations : connect_from_nodes -> no path found") - - if print_out_info: - print("max nbr of station slots with given configuration is:", station_slots_cnt) - - return nodes_added, station_slots, start_nodes_added, end_nodes_added, station_tracks - - # noinspection PyTypeChecker - def create_switches_at_stations(rail_trans: RailEnvTransitions, - grid_map: GridTransitionMap, - station_tracks: IntVector2DArrayArray, - nodes_added: IntVector2DArray, - intern_nbr_of_switches_per_station_track: int) -> IntVector2DArray: - - for k_loop in range(intern_nbr_of_switches_per_station_track): - for city_loop in range(len(station_tracks)): - k = k_loop + city_loop - datas = station_tracks[city_loop] - if len(datas) > 1: - - track = datas[0] - if len(track) > 0: - if k % 2 == 0: - x = int(np.random.choice(int(len(track) / 2)) + 1) - else: - x = len(track) - int(np.random.choice(int(len(track) / 2)) + 1) - start_node = track[x] - for i in np.arange(1, len(datas)): - track = datas[i] - if len(track) > 1: - if k % 2 == 0: - x = x + 2 - if len(track) <= x: - x = 1 - else: - x = x - 2 - if x < 2: - x = len(track) - 1 - end_node = track[x] - connection = connect_rail(rail_trans, grid_map, start_node, end_node, - a_star_distance_function) - if len(connection) == 0: - if print_out_info: - print("create_switches_at_stations : connect_rail -> no path found") - start_node = datas[i][0] - end_node = datas[i - 1][0] - connect_rail(rail_trans, grid_map, start_node, end_node, a_star_distance_function) - - nodes_added.append(start_node) - nodes_added.append(end_node) - - if k % 2 == 0: - x = x + 2 - if len(track) <= x: - x = 1 - else: - x = x - 2 - if x < 2: - x = len(track) - 2 - start_node = track[x] - - return nodes_added - - def create_graph_edge(from_city_index: int, to_city_index: int) -> (int, int, int): - return from_city_index, to_city_index, np.inf - - def calc_nbr_of_graphs(graph: []) -> ([], []): - for i in range(len(graph)): - for j in range(len(graph)): - a = graph[i] - b = graph[j] - connected = False - if a[0] == b[0] or a[1] == b[0]: - connected = True - if a[0] == b[1] or a[1] == b[1]: - connected = True - - if connected: - a = [graph[i][0], graph[i][1], graph[i][2]] - b = [graph[j][0], graph[j][1], graph[j][2]] - graph[i] = (graph[i][0], graph[i][1], min(np.min(a), np.min(b))) - graph[j] = (graph[j][0], graph[j][1], min(np.min(a), np.min(b))) - else: - a = [graph[i][0], graph[i][1], graph[i][2]] - graph[i] = (graph[i][0], graph[i][1], np.min(a)) - b = [graph[j][0], graph[j][1], graph[j][2]] - graph[j] = (graph[j][0], graph[j][1], np.min(b)) - - graph_ids = [] - for i in range(len(graph)): - graph_ids.append(graph[i][2]) - if print_out_info: - print("************* NBR of graphs:", len(np.unique(graph_ids))) - return graph, np.unique(graph_ids).astype(int) - - def connect_sub_graphs(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - org_s_nodes: IntVector2DArrayArray, - org_e_nodes: IntVector2DArrayArray, - city_edges: IntVector2DArray, - nodes_added: IntVector2DArray): - _, graphids = calc_nbr_of_graphs(city_edges) - if len(graphids) > 0: - for i in range(len(graphids) - 1): - connection = [] - iteration_counter = 0 - while len(connection) == 0 and iteration_counter < 100: - s_nodes = copy.deepcopy(org_s_nodes) - e_nodes = copy.deepcopy(org_e_nodes) - start_nodes = s_nodes[graphids[i]] - end_nodes = e_nodes[graphids[i + 1]] - start_node = start_nodes[np.random.choice(len(start_nodes))] - end_node = end_nodes[np.random.choice(len(end_nodes))] - # TODO : removing, what the hell is going on, why we have to set rail_array -> transition to zero - # TODO : before we can call connect_rail. If we don't reset the transistion to zero -> no rail - # TODO : will be generated. - grid_map.grid[start_node] = 0 - grid_map.grid[end_node] = 0 - connection = connect_rail(rail_trans, grid_map, start_node, end_node, a_star_distance_function) - if len(connection) > 0: - nodes_added.append(start_node) - nodes_added.append(end_node) - else: - if print_out_info: - print("connect_sub_graphs : connect_rail -> no path found") - - iteration_counter += 1 - - def connect_stations(rail_trans: RailEnvTransitions, - grid_map: GridTransitionMap, - org_s_nodes: IntVector2DArrayArray, - org_e_nodes: IntVector2DArrayArray, - nodes_added: IntVector2DArray, - intern_connect_max_nbr_of_shortes_city: int): - city_edges = [] - - s_nodes: IntVector2DArrayArray = copy.deepcopy(org_s_nodes) - e_nodes: IntVector2DArrayArray = copy.deepcopy(org_e_nodes) - - for nbr_connected in range(intern_connect_max_nbr_of_shortes_city): - for city_loop in range(len(s_nodes)): - sns = s_nodes[city_loop] - for start_node in sns: - min_distance = np.inf - end_node = None - cl = 0 - for city_loop_find_shortest in range(len(e_nodes)): - if city_loop_find_shortest == city_loop: - continue - ens = e_nodes[city_loop_find_shortest] - for en in ens: - d = Vec2d.get_euclidean_distance(start_node, en) - if d < min_distance: - min_distance = d - end_node = en - cl = city_loop_find_shortest - - if end_node is not None: - tmp_trans_sn = grid_map.grid[start_node] - tmp_trans_en = grid_map.grid[end_node] - grid_map.grid[start_node] = 0 - grid_map.grid[end_node] = 0 - connection = connect_rail(rail_trans, grid_map, start_node, end_node, a_star_distance_function) - if len(connection) > 0: - s_nodes[city_loop].remove(start_node) - e_nodes[cl].remove(end_node) - - edge = create_graph_edge(city_loop, cl) - if city_loop > cl: - edge = create_graph_edge(cl, city_loop) - if not (edge in city_edges): - city_edges.append(edge) - nodes_added.append(start_node) - nodes_added.append(end_node) - else: - if print_out_info: - print("connect_stations : connect_rail -> no path found") - - grid_map.grid[start_node] = tmp_trans_sn - grid_map.grid[end_node] = tmp_trans_en - - connect_sub_graphs(rail_trans, grid_map, org_s_nodes, org_e_nodes, city_edges, nodes_added) - - def connect_random_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - start_nodes_added: IntVector2DArray, - end_nodes_added: IntVector2DArray, - nodes_added: IntVector2DArray, - intern_connect_max_nbr_of_shortes_city: int): - if len(start_nodes_added) < 1: - return - x = np.arange(len(start_nodes_added)) - random_city_idx = np.random.choice(x, len(x), False) - - # cyclic connection - random_city_idx = np.append(random_city_idx, random_city_idx[0]) - - for city_loop in range(len(random_city_idx) - 1): - idx_a = random_city_idx[city_loop + 1] - idx_b = random_city_idx[city_loop] - s_nodes = start_nodes_added[idx_a] - e_nodes = end_nodes_added[idx_b] - - max_input_output = max(len(s_nodes), len(e_nodes)) - max_input_output = min(intern_connect_max_nbr_of_shortes_city, max_input_output) - - idx_s_nodes = np.random.choice(np.arange(len(s_nodes)), len(s_nodes), False) - idx_e_nodes = np.random.choice(np.arange(len(e_nodes)), len(e_nodes), False) - - if len(idx_s_nodes) < max_input_output: - idx_s_nodes = np.append(idx_s_nodes, np.random.choice(np.arange(len(s_nodes)), max_input_output - len( - idx_s_nodes))) - if len(idx_e_nodes) < max_input_output: - idx_e_nodes = np.append(idx_e_nodes, - np.random.choice(np.arange(len(idx_e_nodes)), max_input_output - len( - idx_e_nodes))) - - if len(idx_s_nodes) > intern_connect_max_nbr_of_shortes_city: - idx_s_nodes = np.random.choice(idx_s_nodes, intern_connect_max_nbr_of_shortes_city, False) - if len(idx_e_nodes) > intern_connect_max_nbr_of_shortes_city: - idx_e_nodes = np.random.choice(idx_e_nodes, intern_connect_max_nbr_of_shortes_city, False) - - for i in range(max_input_output): - start_node = s_nodes[idx_s_nodes[i]] - end_node = e_nodes[idx_e_nodes[i]] - grid_map.grid[start_node] = 0 - grid_map.grid[end_node] = 0 - connection = connect_nodes(rail_trans, grid_map, start_node, end_node, a_star_distance_function) - if len(connection) > 0: - nodes_added.append(start_node) - nodes_added.append(end_node) - else: - if print_out_info: - print("connect_random_stations : connect_nodes -> no path found") - - def remove_switch_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - train_stations: IntVector2DArray): - tmp_train_stations = copy.deepcopy(train_stations) - for city_loop in range(len(train_stations)): - for n in tmp_train_stations[city_loop]: - do_remove = True - trans = rail_trans.transition_list[1] - for _ in range(4): - trans = rail_trans.rotate_transition(trans, rotation=90) - if grid_map.grid[n] == trans: - do_remove = False - if do_remove: - train_stations[city_loop].remove(n) - - def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct: - rail_trans = RailEnvTransitions() - grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) - grid_map.grid.fill(0) - np.random.seed(seed + num_resets) - - intern_city_size = city_size - if city_size < 3: - warnings.warn("min city_size requried to be > 3!") - intern_city_size = 3 - if print_out_info: - print("intern_city_size:", intern_city_size) - - intern_max_number_of_station_tracks = max_number_of_station_tracks - if max_number_of_station_tracks < 1: - warnings.warn("min max_number_of_station_tracks requried to be > 1!") - intern_max_number_of_station_tracks = 1 - if print_out_info: - print("intern_max_number_of_station_tracks:", intern_max_number_of_station_tracks) - - intern_nbr_of_switches_per_station_track = nbr_of_switches_per_station_track - if nbr_of_switches_per_station_track < 1: - warnings.warn("min intern_nbr_of_switches_per_station_track requried to be > 2!") - intern_nbr_of_switches_per_station_track = 2 - if print_out_info: - print("intern_nbr_of_switches_per_station_track:", intern_nbr_of_switches_per_station_track) - - intern_connect_max_nbr_of_shortes_city = connect_max_nbr_of_shortes_city - if connect_max_nbr_of_shortes_city < 1: - warnings.warn("min intern_connect_max_nbr_of_shortes_city requried to be > 1!") - intern_connect_max_nbr_of_shortes_city = 1 - if print_out_info: - print("intern_connect_max_nbr_of_shortes_city:", intern_connect_max_nbr_of_shortes_city) - - # ---------------------------------------------------------------------------------- - # generate city locations - generate_city_locations, max_num_cities = do_generate_city_locations(width, height, intern_city_size, - intern_max_number_of_station_tracks) - - # ---------------------------------------------------------------------------------- - # apply orientation to cities (horizontal, vertical) - generate_city_locations = do_orient_cities(generate_city_locations, intern_city_size, allowed_rotation_angles) - - # ---------------------------------------------------------------------------------- - # generate city topology - nodes_added, train_stations_slots, s_nodes, e_nodes, station_tracks = \ - create_stations_from_city_locations(rail_trans, grid_map, - generate_city_locations, - intern_max_number_of_station_tracks) - # build switches - create_switches_at_stations(rail_trans, grid_map, station_tracks, nodes_added, - intern_nbr_of_switches_per_station_track) - - # ---------------------------------------------------------------------------------- - # connect stations - if do_random_connect_stations: - connect_random_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added, - intern_connect_max_nbr_of_shortes_city) - else: - connect_stations(rail_trans, grid_map, s_nodes, e_nodes, nodes_added, - intern_connect_max_nbr_of_shortes_city) - - # ---------------------------------------------------------------------------------- - # fix all transition at starting / ending points (mostly add a dead end, if missing) - # TODO we might have to remove the fixing stuff in the future - for i in range(len(nodes_added)): - grid_map.fix_transitions(nodes_added[i]) - - # ---------------------------------------------------------------------------------- - # remove stations where underlaying rail is a switch - remove_switch_stations(rail_trans, grid_map, train_stations_slots) - - # ---------------------------------------------------------------------------------- - # Slot availability in node - node_available_start = [] - node_available_target = [] - for node_idx in range(max_num_cities): - node_available_start.append(len(train_stations_slots[node_idx])) - node_available_target.append(len(train_stations_slots[node_idx])) - - # Assign agents to slots - agent_start_targets_nodes = [] - for agent_idx in range(num_agents): - avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0] - avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0] - if len(avail_target_nodes) == 0: - num_agents -= 1 - continue - start_node = np.random.choice(avail_start_nodes) - target_node = np.random.choice(avail_target_nodes) - tries = 0 - found_agent_pair = True - while target_node == start_node: - target_node = np.random.choice(avail_target_nodes) - tries += 1 - # Test again with new start node if no pair is found (This code needs to be improved) - if (tries + 1) % 10 == 0: - start_node = np.random.choice(avail_start_nodes) - if tries > 100: - warnings.warn("Could not set train_stations, removing agent!") - found_agent_pair = False - break - if found_agent_pair: - node_available_start[start_node] -= 1 - node_available_target[target_node] -= 1 - agent_start_targets_nodes.append((start_node, target_node)) - else: - num_agents -= 1 - - return grid_map, {'agents_hints': { - 'num_agents': num_agents, - 'agent_start_targets_nodes': agent_start_targets_nodes, - 'train_stations': train_stations_slots - }} - - return generator diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 5442a0af191ce31e415964224557cb18d05407f1..05c6cc0da831a571c4f0b4f328ff6f3f4703cff1 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -1,4 +1,5 @@ """Schedule generators (railway undertaking, "EVU").""" +import random import warnings from typing import Tuple, List, Callable, Mapping, Optional, Any @@ -57,10 +58,12 @@ def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): train_stations = hints['train_stations'] - agent_start_targets_nodes = hints['agent_start_targets_nodes'] + agent_start_targets_cities = hints['agent_start_targets_cities'] max_num_agents = hints['num_agents'] + city_orientations = hints['city_orientations'] if num_agents > max_num_agents: num_agents = max_num_agents warnings.warn("Too many agents! Changes number of agents.") @@ -70,40 +73,25 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> agents_direction = [] for agent_idx in range(num_agents): # Set target for agent - current_target_node = agent_start_targets_nodes[agent_idx][1] - target_station_idx = np.random.randint(len(train_stations[current_target_node])) - target = train_stations[current_target_node][target_station_idx] - tries = 0 - while (target[0], target[1]) in agents_target: - target_station_idx = np.random.randint(len(train_stations[current_target_node])) - target = train_stations[current_target_node][target_station_idx] - tries += 1 - if tries > 100: - warnings.warn("Could not set target position, removing an agent") - break - agents_target.append((target[0], target[1])) - - # Set start for agent - current_start_node = agent_start_targets_nodes[agent_idx][0] - start_station_idx = np.random.randint(len(train_stations[current_start_node])) - start = train_stations[current_start_node][start_station_idx] - tries = 0 - while (start[0], start[1]) in agents_position: - tries += 1 - if tries > 100: - warnings.warn("Could not set start position, please change initial parameters!!!!") - break - start_station_idx = np.random.randint(len(train_stations[current_start_node])) - start = train_stations[current_start_node][start_station_idx] - - agents_position.append((start[0], start[1])) - + start_city = agent_start_targets_cities[agent_idx][0] + target_city = agent_start_targets_cities[agent_idx][1] + start = random.choice(train_stations[start_city]) + target = random.choice(train_stations[target_city]) + while start[1] % 2 != 0: + start = random.choice(train_stations[start_city]) + while target[1] % 2 != 1: + target = random.choice(train_stations[target_city]) + + agent_orientation = (agent_start_targets_cities[agent_idx][2] + 2 * start[1]) % 4 + if not rail.check_path_exists(start[0], agent_orientation, target[0]): + agent_orientation = (agent_orientation + 2) % 4 + if not (rail.check_path_exists(start[0], agent_orientation, target[0])): + warnings.warn("Infeasible") + agents_position.append((start[0][0], start[0][1])) + agents_target.append((target[0][0], target[0][1])) + + agents_direction.append(agent_orientation) # Orient the agent correctly - for orientation in range(4): - transitions = rail.get_transitions(start[0], start[1], orientation) - if any(transitions) > 0: - agents_direction.append(orientation) - break if speed_ratio_map: speeds = speed_initialization_helper(num_agents, speed_ratio_map) @@ -248,7 +236,3 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: return generator - -# we can us the same schedule generator for city_rail_generator -# in order to be able to change this transparently in the future, we use a different name. -city_schedule_generator = sparse_schedule_generator diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py index b2dac4637d1560f9d1ac5ae485d8bf8bd5c2c97e..30ad93a24931e67b8f671a03f38a7e180c6e71f3 100644 --- a/flatland/utils/graphics_layer.py +++ b/flatland/utils/graphics_layer.py @@ -65,7 +65,7 @@ class GraphicsLayer(object): def get_cmap(self, *args, **kwargs): return plt.get_cmap(*args, **kwargs) - def set_rail_at(self, row, col, binTrans, iTarget=None, isSelected=False, rail_grid=None): + def set_rail_at(self, row, col, binTrans, iTarget=None, isSelected=False, rail_grid=None, num_agents=None): """ Set the rail at cell (row, col) to have transitions binTrans. The target argument can contain the index of the agent to indicate that agent's target is at that cell, so that a station can be @@ -73,7 +73,8 @@ class GraphicsLayer(object): """ pass - def set_agent_at(self, iAgent, row, col, iDirIn, iDirOut, isSelected=False): + def set_agent_at(self, iAgent, row, col, iDirIn, iDirOut, isSelected=False,rail_grid=None,show_debug=False, + clear_debug_text=True): pass def set_cell_occupied(self, iAgent, row, col): diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 4dad2ca872725517ffd47308f088f098b6abe1aa..377e1ccb65956a8f70e7223ad37779f43e2ccee4 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -109,7 +109,11 @@ class PILGL(GraphicsLayer): rebuild = True if rebuild: + # rebuild background_grid to control the visualisation of buildings, trees, mountains, lakes and river self.background_grid = np.zeros(shape=(self.width, self.height)) + + + # build base distance map (distance to targets) for x in range(self.width): for y in range(self.height): distance = int(np.ceil(np.sqrt(self.width ** 2.0 + self.height ** 2.0))) @@ -357,13 +361,19 @@ class PILSVG(PILGL): ] scenery_files_d3 = [ - "Scenery-Bergwelt_A_Teil_3_rechts.svg", + "Scenery-Bergwelt_A_Teil_1_links.svg", "Scenery-Bergwelt_A_Teil_2_mitte.svg", - "Scenery-Bergwelt_A_Teil_1_links.svg" + "Scenery-Bergwelt_A_Teil_3_rechts.svg" + ] + + scenery_files_water = [ + "Scenery_Water.svg" ] img_back_ground = self.pil_from_svg_file('svg', "Background_Light_green.svg") + self.scenery_background_white = self.pil_from_svg_file('svg', "Background_white.svg") + self.scenery = [] for file in scenery_files: img = self.pil_from_svg_file('svg', file) @@ -382,6 +392,12 @@ class PILSVG(PILGL): img = Image.alpha_composite(img_back_ground, img) self.scenery_d3.append(img) + self.scenery_water = [] + for file in scenery_files_water: + img = self.pil_from_svg_file('svg', file) + img = Image.alpha_composite(img_back_ground, img) + self.scenery_water.append(img) + def load_rail(self): """ Load the rail SVG images, apply rotations, and store as PIL images. """ @@ -499,7 +515,7 @@ class PILSVG(PILGL): False)[0] self.draw_image_row_col(colored_rail, (row, col), layer=PILGL.PREDICTION_PATH_LAYER) - def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None, + def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None, num_agents=None, show_debug=True): if binary_trans in self.pil_rail: @@ -511,8 +527,12 @@ class PILSVG(PILGL): if show_debug: self.text_rowcol((row + 0.8, col + 0.0), strText=str(target), layer=PILGL.TARGET_LAYER) + city_size = 1 + if num_agents is not None: + city_size = max(1, np.log(1 + num_agents) / 2.5) + if binary_trans == 0: - if self.background_grid[col][row] <= 4: + if self.background_grid[col][row] <= 4 + np.ceil(((col * row + col) % 10) / city_size): a = int(self.background_grid[col][row]) a = a % len(self.dBuildings) if (col + row + col * row) % 13 > 11: @@ -521,12 +541,37 @@ class PILSVG(PILGL): if (col + row + col * row) % 3 == 0: a = (a + (col + row + col * row)) % len(self.dBuildings) pil_track = self.dBuildings[a] - elif (self.background_grid[col][row] > 4) or ((col ** 3 + row ** 2 + col * row) % 10 == 0): + elif (self.background_grid[col][row] > 5 + ((col * row + col) % 3)) or ( + (col ** 3 + row ** 2 + col * row) % + 10 == 0): a = int(self.background_grid[col][row]) - 4 a2 = (a + (col + row + col * row + col ** 3 + row ** 4)) - if a2 % 17 > 11: + if a2 % 64 > 11: a = a2 - pil_track = self.scenery[a % len(self.scenery)] + a_l = a % len(self.scenery) + if a2 % 50 == 49: + pil_track = self.scenery_water[0] + else: + pil_track = self.scenery[a_l] + if rail_grid is not None: + if a2 % 11 > 3: + if a_l == len(self.scenery) - 1: + # mountain + if col > 1 and row % 7 == 1: + if rail_grid[row, col - 1] == 0: + self.draw_image_row_col(self.scenery_d2[0], (row, col - 1), + layer=PILGL.RAIL_LAYER) + pil_track = self.scenery_d2[1] + else: + if a_l == len(self.scenery) - 1: + # mountain + if col > 2 and not (row % 7 == 1): + if rail_grid[row, col - 2] == 0 and rail_grid[row, col - 1] == 0: + self.draw_image_row_col(self.scenery_d3[0], (row, col - 2), + layer=PILGL.RAIL_LAYER) + self.draw_image_row_col(self.scenery_d3[1], (row, col - 1), + layer=PILGL.RAIL_LAYER) + pil_track = self.scenery_d3[2] self.draw_image_row_col(pil_track, (row, col), layer=PILGL.RAIL_LAYER) else: @@ -590,7 +635,8 @@ class PILSVG(PILGL): for color_idx, pil_zug_3 in enumerate(pils): self.pil_zug[(in_direction_2, out_direction_2, color_idx)] = pils[color_idx] - def set_agent_at(self, agent_idx, row, col, in_direction, out_direction, is_selected, show_debug=False): + def set_agent_at(self, agent_idx, row, col, in_direction, out_direction, is_selected, + rail_grid=None, show_debug=False,clear_debug_text=True): delta_dir = (out_direction - in_direction) % 4 color_idx = agent_idx % self.n_agent_colors # when flipping direction at a dead end, use the "out_direction" direction. @@ -598,14 +644,34 @@ class PILSVG(PILGL): in_direction = out_direction pil_zug = self.pil_zug[(in_direction % 4, out_direction % 4, color_idx)] self.draw_image_row_col(pil_zug, (row, col), layer=PILGL.AGENT_LAYER) + if rail_grid is not None: + if rail_grid[row, col] == 0.0: + self.draw_image_row_col(self.scenery_background_white, (row, col), layer=PILGL.RAIL_LAYER) if is_selected: bg_svg = self.pil_from_svg_file("svg", "Selected_Agent.svg") self.clear_layer(PILGL.SELECTED_AGENT_LAYER, 0) self.draw_image_row_col(bg_svg, (row, col), layer=PILGL.SELECTED_AGENT_LAYER) - if show_debug: - self.text_rowcol((row + 0.2, col + 0.2,), str(agent_idx)) + if not clear_debug_text: + dr = 0.2 + dc = 0.2 + if in_direction == 0: + dr = 0.8 + dc = 0.0 + if in_direction == 1: + dr = 0.0 + dc = 0.8 + if in_direction == 2: + dr = 0.4 + dc = 0.8 + if in_direction == 3: + dr = 0.8 + dc = 0.4 + + self.text_rowcol((row + dr, col + dc,), str(agent_idx), layer=PILGL.SELECTED_AGENT_LAYER) + else: + self.text_rowcol((row + 0.2, col + 0.2,), str(agent_idx)) def set_cell_occupied(self, agent_idx, row, col): occupied_im = self.cell_occupied[agent_idx % len(self.cell_occupied)] diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index e7b1e72679937bc5b3093cebfa58bd2e9894ba94..c238319ad17bb09d9dbaea80335685ce1283feb1 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -41,7 +41,7 @@ class RenderTool(object): def __init__(self, env, gl="PILSVG", jupyter=False, agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, - show_debug=False, screen_width=800, screen_height=600): + show_debug=False, clear_debug_text=True,screen_width=800, screen_height=600): self.env = env self.frame_nr = 0 @@ -60,6 +60,7 @@ class RenderTool(object): self.new_rail = True self.show_debug = show_debug + self.clear_debug_text = clear_debug_text self.update_background() def reset(self): @@ -532,7 +533,8 @@ class RenderTool(object): is_selected = False self.gl.set_rail_at(r, c, transitions, target=target, is_selected=is_selected, - rail_grid=env.rail.grid, show_debug=self.show_debug) + rail_grid=env.rail.grid, num_agents=env.get_num_agents(), + show_debug=self.show_debug) self.gl.build_background_map(targets) @@ -558,7 +560,8 @@ class RenderTool(object): if self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX: self.gl.set_cell_occupied(agent_idx, *(agent.position)) self.gl.set_agent_at(agent_idx, *position, old_direction, direction, - selected_agent == agent_idx, show_debug=self.show_debug) + selected_agent == agent_idx, rail_grid=env.rail.grid, + show_debug=self.show_debug,clear_debug_text=self.clear_debug_text) else: position = agent.position direction = agent.direction @@ -570,12 +573,14 @@ class RenderTool(object): # set_agent_at uses the agent index for the color self.gl.set_agent_at(agent_idx, *position, agent.direction, direction, - selected_agent == agent_idx, show_debug=self.show_debug) + selected_agent == agent_idx, rail_grid=env.rail.grid, + show_debug=self.show_debug,clear_debug_text=self.clear_debug_text) # set_agent_at uses the agent index for the color if self.agent_render_variant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX: self.gl.set_cell_occupied(agent_idx, *(agent.position)) - self.gl.set_agent_at(agent_idx, *position, agent.direction, direction, selected_agent == agent_idx) + self.gl.set_agent_at(agent_idx, *position, agent.direction, direction, selected_agent == agent_idx, + rail_grid=env.rail.grid) if show_observations: self.render_observation(range(env.get_num_agents()), env.dev_obs_dict) diff --git a/notebooks/Scene_Editor.ipynb b/notebooks/Scene_Editor.ipynb index acf418e6ef1fc6da9123995f58df79903b712e10..5ee1c8390c0aefd73068fb1c92b1a7f44e00d093 100644 --- a/notebooks/Scene_Editor.ipynb +++ b/notebooks/Scene_Editor.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -62,7 +62,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 6, "metadata": { "scrolled": false }, @@ -70,7 +70,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "84809032a2f84b908e889f90594b3d62", + "model_id": "8e968bb6c8204596a24df0a8ad9e8440", "version_major": 2, "version_minor": 0 }, @@ -80,18 +80,6 @@ }, "metadata": {}, "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "load file: temp.pkl\n", - "Regenerate size 5 5\n", - "load file: temp.pkl\n", - "load file: temp.pkl\n", - "Regenerate size 5 5\n", - "load file: temp.pkl\n" - ] } ], "source": [ diff --git a/svg/Background_white.svg b/svg/Background_white.svg new file mode 100644 index 0000000000000000000000000000000000000000..5c72b28287f43d332a99b66432699d4491a44eb7 --- /dev/null +++ b/svg/Background_white.svg @@ -0,0 +1,55 @@ +<?xml version="1.0" encoding="UTF-8" standalone="no"?> +<!-- Generator: Adobe Illustrator 23.0.3, SVG Export Plug-In . SVG Version: 6.00 Build 0) --> + +<svg + xmlns:dc="http://purl.org/dc/elements/1.1/" + xmlns:cc="http://creativecommons.org/ns#" + xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#" + xmlns:svg="http://www.w3.org/2000/svg" + xmlns="http://www.w3.org/2000/svg" + xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd" + xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape" + version="1.1" + id="Ebene_1" + x="0px" + y="0px" + viewBox="0 0 240 240" + style="enable-background:new 0 0 240 240;" + xml:space="preserve" + sodipodi:docname="Background_white.svg" + inkscape:version="0.92.4 (5da689c313, 2019-01-14)"><metadata + id="metadata11"><rdf:RDF><cc:Work + rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type + rdf:resource="http://purl.org/dc/dcmitype/StillImage" /></cc:Work></rdf:RDF></metadata><defs + id="defs9" /><sodipodi:namedview + pagecolor="#ffffff" + bordercolor="#666666" + borderopacity="1" + objecttolerance="10" + gridtolerance="10" + guidetolerance="10" + inkscape:pageopacity="0" + inkscape:pageshadow="2" + inkscape:window-width="1920" + inkscape:window-height="1017" + id="namedview7" + showgrid="false" + inkscape:zoom="0.98333333" + inkscape:cx="120" + inkscape:cy="120" + inkscape:window-x="-8" + inkscape:window-y="-8" + inkscape:window-maximized="1" + inkscape:current-layer="Ebene_1" /> +<style + type="text/css" + id="style2"> + .st0{fill:#DEBDA0;} +</style> +<rect + class="st0" + width="240" + height="240" + id="rect4" + style="fill:#f9f9f9" /> +</svg> \ No newline at end of file diff --git a/svg/Scenery_Water.svg b/svg/Scenery_Water.svg new file mode 100644 index 0000000000000000000000000000000000000000..27ae56d073db82857a6fec28ddd5d0f73e87c11f --- /dev/null +++ b/svg/Scenery_Water.svg @@ -0,0 +1,143 @@ +<?xml version="1.0" encoding="UTF-8" standalone="no"?> +<!-- Generator: Adobe Illustrator 23.0.3, SVG Export Plug-In . SVG Version: 6.00 Build 0) --> + +<svg + xmlns:dc="http://purl.org/dc/elements/1.1/" + xmlns:cc="http://creativecommons.org/ns#" + xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#" + xmlns:svg="http://www.w3.org/2000/svg" + xmlns="http://www.w3.org/2000/svg" + xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd" + xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape" + version="1.1" + id="Ebene_1" + x="0px" + y="0px" + viewBox="0 0 240 240" + style="enable-background:new 0 0 240 240;" + xml:space="preserve" + sodipodi:docname="Scenery_Water.svg" + inkscape:version="0.92.4 (5da689c313, 2019-01-14)"><metadata + id="metadata47"><rdf:RDF><cc:Work + rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type + rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title /></cc:Work></rdf:RDF></metadata><defs + id="defs45"> + + + + + + + + + + + + + + + + +</defs><sodipodi:namedview + pagecolor="#ffffff" + bordercolor="#666666" + borderopacity="1" + objecttolerance="10" + gridtolerance="10" + guidetolerance="10" + inkscape:pageopacity="0" + inkscape:pageshadow="2" + inkscape:window-width="1920" + inkscape:window-height="1017" + id="namedview43" + showgrid="false" + inkscape:zoom="0.98333333" + inkscape:cx="-102.20339" + inkscape:cy="120" + inkscape:window-x="-8" + inkscape:window-y="-8" + inkscape:window-maximized="1" + inkscape:current-layer="Ebene_1" /> +<style + type="text/css" + id="style2"> + .st0{fill:none;} + .st1{fill:#8B5420;} + .st2{fill:#A8642A;} + .st3{fill:#5DAD61;} + .st4{fill:#418050;} +</style> +<g + id="g6"> + <rect + x="0" + class="st0" + width="240" + height="240" + id="rect4" /> +</g> + +<g + id="g980" + style="fill:#4cc1ff;fill-opacity:1;stroke-width:1.70416141" + transform="matrix(0.58536585,0,0,0.5882353,319.96693,139.86042)"><ellipse + ry="26.440678" + rx="141.86441" + cy="-48.81356" + cx="-355.42374" + id="path948" + style="opacity:1;fill:#4cc1ff;fill-opacity:1;stroke:none;stroke-width:1.70416141;stroke-opacity:1" /><ellipse + ry="39.661018" + rx="77.796608" + cy="6.1016951" + cx="-228.30508" + id="path950" + style="opacity:1;fill:#4cc1ff;fill-opacity:1;stroke:none;stroke-width:1.70416141;stroke-opacity:1" /><ellipse + ry="15.762712" + rx="30" + cy="7.6271186" + cx="-216.1017" + id="path952" + style="opacity:1;fill:#4cc1ff;fill-opacity:1;stroke:none;stroke-width:1.70416141;stroke-opacity:1" /><ellipse + ry="21.864407" + rx="92.033897" + cy="-12.711864" + cx="-333.05084" + id="path954" + style="opacity:1;fill:#4cc1ff;fill-opacity:1;stroke:none;stroke-width:1.70416141;stroke-opacity:1" /><ellipse + ry="38.135593" + rx="65.593224" + cy="-21.864407" + cx="-457.11865" + id="path956" + style="opacity:1;fill:#4cc1ff;fill-opacity:1;stroke:none;stroke-width:1.70416141;stroke-opacity:1" /><ellipse + ry="19.830509" + rx="140.84746" + cy="18.813559" + cx="-349.32202" + id="path958" + style="opacity:1;fill:#4cc1ff;fill-opacity:1;stroke:none;stroke-width:1.70416141;stroke-opacity:1" /><ellipse + ry="10.677966" + rx="62.542374" + cy="-36.101696" + cx="-264.91525" + id="path960" + style="opacity:1;fill:#4cc1ff;fill-opacity:1;stroke:none;stroke-width:1.70416141;stroke-opacity:1" /><ellipse + ry="46.271187" + rx="77.288132" + cy="-17.79661" + cx="-253.22034" + id="path962" + style="opacity:1;fill:#4cc1ff;fill-opacity:1;stroke:none;stroke-width:1.70416141;stroke-opacity:1" /><ellipse + ry="36.610168" + rx="56.440678" + cy="-28.474577" + cx="-456.10168" + id="path964" + style="opacity:1;fill:#4cc1ff;fill-opacity:1;stroke:none;stroke-width:1.70416141;stroke-opacity:1" /><ellipse + ry="20.338984" + rx="79.322037" + cy="24.40678" + cx="-446.44067" + id="path966" + style="opacity:1;fill:#4cc1ff;fill-opacity:1;stroke:none;stroke-width:1.70416141;stroke-opacity:1" /></g></svg> \ No newline at end of file diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py index b4396c101d1879bf8f5f74f9a79f9a8541d45e31..859ea230b2efa1bc8ebe0aa573615d06fae6e803 100644 --- a/tests/test_flaltland_rail_agent_status.py +++ b/tests/test_flaltland_rail_agent_status.py @@ -122,3 +122,114 @@ def test_initial_status(): ) run_replay_config(env, [test_config], activate_agents=False) + +def test_status_done_remove(): + """Test that agent lifecycle works correctly ready-to-depart -> active -> done.""" + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + remove_agents_at_target=True + ) + + set_penalties_for_replay(env) + test_config = ReplayConfig( + replay=[ + Replay( + position=None, # not entered grid yet + direction=Grid4TransitionsEnum.EAST, + status=RailAgentStatus.READY_TO_DEPART, + action=RailEnvActions.DO_NOTHING, + reward=0, + + ), + Replay( + position=None, # not entered grid yet before step + direction=Grid4TransitionsEnum.EAST, + status=RailAgentStatus.READY_TO_DEPART, + action=RailEnvActions.MOVE_LEFT, + reward=env.start_penalty + env.step_penalty * 0.5, # auto-correction left to forward without penalty! + ), + Replay( + position=(3, 9), + direction=Grid4TransitionsEnum.EAST, + status=RailAgentStatus.ACTIVE, + action=None, + reward=env.step_penalty * 0.5, # running at speed 0.5 + ), + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + status=RailAgentStatus.ACTIVE, + action=RailEnvActions.MOVE_FORWARD, + reward=env.step_penalty * 0.5, # running at speed 0.5 + ), + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + status=RailAgentStatus.ACTIVE, + action=None, + reward=env.step_penalty * 0.5, # running at speed 0.5 + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + status=RailAgentStatus.ACTIVE, + action=RailEnvActions.MOVE_FORWARD, + reward=env.step_penalty * 0.5, # running at speed 0.5 + + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=None, + reward=env.step_penalty * 0.5, # running at speed 0.5 + status=RailAgentStatus.ACTIVE + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_RIGHT, + reward=env.step_penalty * 0.5, # wrong action is corrected to forward without penalty! + status=RailAgentStatus.ACTIVE + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=None, + reward=env.global_reward, # done + status=RailAgentStatus.ACTIVE + ), + Replay( + position=None, + direction=Grid4TransitionsEnum.WEST, + action=None, + reward=env.global_reward, # already done + status=RailAgentStatus.DONE_REMOVED + ), + Replay( + position=None, + direction=Grid4TransitionsEnum.WEST, + action=None, + reward=env.global_reward, # already done + status=RailAgentStatus.DONE_REMOVED + ), + Replay( + position=None, + direction=Grid4TransitionsEnum.WEST, + action=None, + reward=env.global_reward, # already done + status=RailAgentStatus.DONE_REMOVED + ) + + ], + initial_position=(3, 9), # east dead-end + initial_direction=Grid4TransitionsEnum.EAST, + target=(3, 5), + speed=0.5 + ) + + run_replay_config(env, [test_config], activate_agents=False) diff --git a/tests/test_flatland_core_grid4_generators_util.py b/tests/test_flatland_core_grid4_generators_util.py index 72deddc66eacc71a5aa840b49225e5ba056a8b84..b99a2624bb8860c6ff6c3cf773e57d68c81ed02f 100644 --- a/tests/test_flatland_core_grid4_generators_util.py +++ b/tests/test_flatland_core_grid4_generators_util.py @@ -1,9 +1,8 @@ import numpy as np -from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes, connect_to_nodes +from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map def test_build_railway_infrastructure(): @@ -12,25 +11,37 @@ def test_build_railway_infrastructure(): grid_map.grid.fill(0) np.random.seed(0) + # Make connection with dead-ends on both sides start_point = (2, 2) end_point = (8, 8) - connection_001 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_001 = connect_rail_in_grid_map(grid_map, start_point, end_point, rail_trans, flip_start_node_trans=True, + flip_end_node_trans=True, respect_transition_validity=True, + forbidden_cells=None) connection_001_expected = [(2, 2), (2, 3), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (3, 8), (4, 8), (5, 8), (6, 8), (7, 8), (8, 8)] + # Make connection with open ends on both sides start_point = (1, 3) end_point = (1, 7) - connection_002 = connect_nodes(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_002 = connect_rail_in_grid_map(grid_map, start_point, end_point, rail_trans, flip_start_node_trans=False, + flip_end_node_trans=False, respect_transition_validity=True, + forbidden_cells=None) connection_002_expected = [(1, 3), (1, 4), (1, 5), (1, 6), (1, 7)] + # Make connection with open end at beginning and dead end on end start_point = (6, 2) end_point = (6, 5) - connection_003 = connect_from_nodes(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_003 = connect_rail_in_grid_map(grid_map, start_point, end_point, rail_trans, flip_start_node_trans=False, + flip_end_node_trans=True, respect_transition_validity=True, + forbidden_cells=None) connection_003_expected = [(6, 2), (6, 3), (6, 4), (6, 5)] + # Make connection with dead end on start and opend end start_point = (7, 5) end_point = (8, 9) - connection_004 = connect_to_nodes(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_004 = connect_rail_in_grid_map(grid_map, start_point, end_point, rail_trans, flip_start_node_trans=True, + flip_end_node_trans=False, respect_transition_validity=True, + forbidden_cells=None) connection_004_expected = [(7, 5), (7, 6), (7, 7), (7, 8), (7, 9), (8, 9)] assert connection_001 == connection_001_expected, \ @@ -64,6 +75,5 @@ def test_build_railway_infrastructure(): [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] ] - - assert np.all(grid_map.grid == grid_map_grid_expected), \ - "actual={}, expected={}".format(grid_map.grid, grid_map_grid_expected) + for i in range(len(grid_map_grid_expected)): + assert np.all(grid_map.grid[i] == grid_map_grid_expected[i]) diff --git a/tests/test_flatland_envs_city_generator.py b/tests/test_flatland_envs_city_generator.py deleted file mode 100644 index 1d386df225e7d025116752e26d5c55cf2a292214..0000000000000000000000000000000000000000 --- a/tests/test_flatland_envs_city_generator.py +++ /dev/null @@ -1,301 +0,0 @@ -import numpy as np - -from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d -from flatland.envs.observations import GlobalObsForRailEnv -from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_generators_city_generator import city_generator -from flatland.envs.schedule_generators import city_schedule_generator - - -def test_city_generator(): - dist_fun = Vec2d.get_manhattan_distance - env = RailEnv(width=50, - height=50, - rail_generator=city_generator(num_cities=5, - city_size=10, - allowed_rotation_angles=[90], - max_number_of_station_tracks=4, - nbr_of_switches_per_station_track=2, - connect_max_nbr_of_shortes_city=2, - do_random_connect_stations=False, - a_star_distance_function=dist_fun, - seed=0, - print_out_info=False - ), - schedule_generator=city_schedule_generator(), - number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv()) - - expected_grid_map = np.zeros((50, 50), dtype=env.rail.transitions.get_type()) - - expected_grid_map[8][16] = 4 - expected_grid_map[8][17] = 5633 - expected_grid_map[8][18] = 1025 - expected_grid_map[8][19] = 1025 - expected_grid_map[8][20] = 17411 - expected_grid_map[8][21] = 1025 - expected_grid_map[8][22] = 1025 - expected_grid_map[8][23] = 1025 - expected_grid_map[8][24] = 1025 - expected_grid_map[8][25] = 1025 - expected_grid_map[8][26] = 4608 - expected_grid_map[9][16] = 16386 - expected_grid_map[9][17] = 50211 - expected_grid_map[9][18] = 1025 - expected_grid_map[9][19] = 1025 - expected_grid_map[9][20] = 3089 - expected_grid_map[9][21] = 1025 - expected_grid_map[9][22] = 256 - expected_grid_map[9][26] = 32800 - expected_grid_map[10][6] = 16386 - expected_grid_map[10][7] = 1025 - expected_grid_map[10][8] = 1025 - expected_grid_map[10][9] = 1025 - expected_grid_map[10][10] = 1025 - expected_grid_map[10][11] = 1025 - expected_grid_map[10][12] = 1025 - expected_grid_map[10][13] = 1025 - expected_grid_map[10][14] = 1025 - expected_grid_map[10][15] = 1025 - expected_grid_map[10][16] = 33825 - expected_grid_map[10][17] = 34864 - expected_grid_map[10][26] = 32800 - expected_grid_map[11][6] = 32800 - expected_grid_map[11][16] = 32800 - expected_grid_map[11][17] = 32800 - expected_grid_map[11][26] = 32800 - expected_grid_map[12][6] = 32800 - expected_grid_map[12][16] = 32800 - expected_grid_map[12][17] = 32800 - expected_grid_map[12][26] = 32800 - expected_grid_map[13][6] = 32800 - expected_grid_map[13][16] = 32800 - expected_grid_map[13][17] = 32800 - expected_grid_map[13][26] = 32800 - expected_grid_map[14][6] = 32800 - expected_grid_map[14][16] = 32800 - expected_grid_map[14][17] = 32800 - expected_grid_map[14][26] = 32800 - expected_grid_map[15][6] = 32800 - expected_grid_map[15][16] = 32800 - expected_grid_map[15][17] = 32800 - expected_grid_map[15][26] = 32800 - expected_grid_map[16][6] = 32800 - expected_grid_map[16][16] = 32800 - expected_grid_map[16][17] = 32800 - expected_grid_map[16][26] = 32800 - expected_grid_map[17][6] = 32800 - expected_grid_map[17][16] = 72 - expected_grid_map[17][17] = 1097 - expected_grid_map[17][18] = 1025 - expected_grid_map[17][19] = 1025 - expected_grid_map[17][20] = 1025 - expected_grid_map[17][21] = 1025 - expected_grid_map[17][22] = 1025 - expected_grid_map[17][23] = 1025 - expected_grid_map[17][24] = 1025 - expected_grid_map[17][25] = 1025 - expected_grid_map[17][26] = 33825 - expected_grid_map[17][27] = 4608 - expected_grid_map[18][6] = 32800 - expected_grid_map[18][26] = 72 - expected_grid_map[18][27] = 52275 - expected_grid_map[18][28] = 5633 - expected_grid_map[18][29] = 17411 - expected_grid_map[18][30] = 1025 - expected_grid_map[18][31] = 1025 - expected_grid_map[18][32] = 256 - expected_grid_map[19][6] = 32800 - expected_grid_map[19][25] = 16386 - expected_grid_map[19][26] = 1025 - expected_grid_map[19][27] = 2136 - expected_grid_map[19][28] = 1097 - expected_grid_map[19][29] = 1097 - expected_grid_map[19][30] = 5633 - expected_grid_map[19][31] = 1025 - expected_grid_map[19][32] = 256 - expected_grid_map[20][6] = 32800 - expected_grid_map[20][25] = 32800 - expected_grid_map[20][26] = 16386 - expected_grid_map[20][27] = 17411 - expected_grid_map[20][28] = 1025 - expected_grid_map[20][29] = 1025 - expected_grid_map[20][30] = 3089 - expected_grid_map[20][31] = 1025 - expected_grid_map[20][32] = 256 - expected_grid_map[21][6] = 32800 - expected_grid_map[21][16] = 16386 - expected_grid_map[21][17] = 1025 - expected_grid_map[21][18] = 1025 - expected_grid_map[21][19] = 1025 - expected_grid_map[21][20] = 1025 - expected_grid_map[21][21] = 1025 - expected_grid_map[21][22] = 1025 - expected_grid_map[21][23] = 1025 - expected_grid_map[21][24] = 1025 - expected_grid_map[21][25] = 33825 - expected_grid_map[21][26] = 33825 - expected_grid_map[21][27] = 2064 - expected_grid_map[22][6] = 32800 - expected_grid_map[22][16] = 32800 - expected_grid_map[22][25] = 32800 - expected_grid_map[22][26] = 32800 - expected_grid_map[23][6] = 32800 - expected_grid_map[23][16] = 32800 - expected_grid_map[23][25] = 32800 - expected_grid_map[23][26] = 32800 - expected_grid_map[24][6] = 32800 - expected_grid_map[24][16] = 32800 - expected_grid_map[24][25] = 32800 - expected_grid_map[24][26] = 32800 - expected_grid_map[25][6] = 32800 - expected_grid_map[25][16] = 32800 - expected_grid_map[25][25] = 32800 - expected_grid_map[25][26] = 32800 - expected_grid_map[26][6] = 32800 - expected_grid_map[26][16] = 32800 - expected_grid_map[26][25] = 32800 - expected_grid_map[26][26] = 32800 - expected_grid_map[27][6] = 72 - expected_grid_map[27][7] = 1025 - expected_grid_map[27][8] = 1025 - expected_grid_map[27][9] = 17411 - expected_grid_map[27][10] = 1025 - expected_grid_map[27][11] = 1025 - expected_grid_map[27][12] = 1025 - expected_grid_map[27][13] = 1025 - expected_grid_map[27][14] = 1025 - expected_grid_map[27][15] = 4608 - expected_grid_map[27][16] = 72 - expected_grid_map[27][17] = 17411 - expected_grid_map[27][18] = 5633 - expected_grid_map[27][19] = 1025 - expected_grid_map[27][20] = 1025 - expected_grid_map[27][21] = 1025 - expected_grid_map[27][22] = 1025 - expected_grid_map[27][23] = 1025 - expected_grid_map[27][24] = 1025 - expected_grid_map[27][25] = 33825 - expected_grid_map[27][26] = 2064 - expected_grid_map[28][6] = 4 - expected_grid_map[28][7] = 1025 - expected_grid_map[28][8] = 1025 - expected_grid_map[28][9] = 3089 - expected_grid_map[28][10] = 1025 - expected_grid_map[28][11] = 1025 - expected_grid_map[28][12] = 1025 - expected_grid_map[28][13] = 1025 - expected_grid_map[28][14] = 4608 - expected_grid_map[28][15] = 72 - expected_grid_map[28][16] = 1025 - expected_grid_map[28][17] = 2136 - expected_grid_map[28][18] = 1097 - expected_grid_map[28][19] = 5633 - expected_grid_map[28][20] = 5633 - expected_grid_map[28][21] = 1025 - expected_grid_map[28][22] = 256 - expected_grid_map[28][25] = 32800 - expected_grid_map[29][6] = 4 - expected_grid_map[29][7] = 5633 - expected_grid_map[29][8] = 20994 - expected_grid_map[29][9] = 5633 - expected_grid_map[29][10] = 1025 - expected_grid_map[29][11] = 1025 - expected_grid_map[29][12] = 1025 - expected_grid_map[29][13] = 1025 - expected_grid_map[29][14] = 1097 - expected_grid_map[29][15] = 5633 - expected_grid_map[29][16] = 1025 - expected_grid_map[29][17] = 17411 - expected_grid_map[29][18] = 5633 - expected_grid_map[29][19] = 1097 - expected_grid_map[29][20] = 3089 - expected_grid_map[29][21] = 20994 - expected_grid_map[29][22] = 1025 - expected_grid_map[29][23] = 1025 - expected_grid_map[29][24] = 1025 - expected_grid_map[29][25] = 2064 - expected_grid_map[30][6] = 16386 - expected_grid_map[30][7] = 38505 - expected_grid_map[30][8] = 3089 - expected_grid_map[30][9] = 1097 - expected_grid_map[30][10] = 1025 - expected_grid_map[30][11] = 1025 - expected_grid_map[30][12] = 256 - expected_grid_map[30][15] = 32800 - expected_grid_map[30][16] = 16386 - expected_grid_map[30][17] = 52275 - expected_grid_map[30][18] = 1097 - expected_grid_map[30][19] = 1025 - expected_grid_map[30][20] = 1025 - expected_grid_map[30][21] = 3089 - expected_grid_map[30][22] = 256 - expected_grid_map[31][6] = 32800 - expected_grid_map[31][7] = 32800 - expected_grid_map[31][15] = 72 - expected_grid_map[31][16] = 37408 - expected_grid_map[31][17] = 32800 - expected_grid_map[32][6] = 32800 - expected_grid_map[32][7] = 32800 - expected_grid_map[32][16] = 32800 - expected_grid_map[32][17] = 32800 - expected_grid_map[33][6] = 32800 - expected_grid_map[33][7] = 32800 - expected_grid_map[33][16] = 32800 - expected_grid_map[33][17] = 32800 - expected_grid_map[34][6] = 32800 - expected_grid_map[34][7] = 32800 - expected_grid_map[34][16] = 32800 - expected_grid_map[34][17] = 32800 - expected_grid_map[35][6] = 32800 - expected_grid_map[35][7] = 32800 - expected_grid_map[35][16] = 32800 - expected_grid_map[35][17] = 32800 - expected_grid_map[36][6] = 32800 - expected_grid_map[36][7] = 32800 - expected_grid_map[36][16] = 32800 - expected_grid_map[36][17] = 32800 - expected_grid_map[37][6] = 72 - expected_grid_map[37][7] = 1097 - expected_grid_map[37][8] = 1025 - expected_grid_map[37][9] = 1025 - expected_grid_map[37][10] = 1025 - expected_grid_map[37][11] = 1025 - expected_grid_map[37][12] = 1025 - expected_grid_map[37][13] = 1025 - expected_grid_map[37][14] = 1025 - expected_grid_map[37][15] = 1025 - expected_grid_map[37][16] = 33897 - expected_grid_map[37][17] = 37408 - expected_grid_map[38][16] = 72 - expected_grid_map[38][17] = 52275 - expected_grid_map[38][18] = 5633 - expected_grid_map[38][19] = 17411 - expected_grid_map[38][20] = 1025 - expected_grid_map[38][21] = 1025 - expected_grid_map[38][22] = 256 - expected_grid_map[39][16] = 4 - expected_grid_map[39][17] = 52275 - expected_grid_map[39][18] = 3089 - expected_grid_map[39][19] = 1097 - expected_grid_map[39][20] = 5633 - expected_grid_map[39][21] = 1025 - expected_grid_map[39][22] = 256 - expected_grid_map[40][16] = 4 - expected_grid_map[40][17] = 1097 - expected_grid_map[40][18] = 1025 - expected_grid_map[40][19] = 1025 - expected_grid_map[40][20] = 3089 - expected_grid_map[40][21] = 1025 - expected_grid_map[40][22] = 256 - - assert np.array_equal(env.rail.grid, expected_grid_map), "actual={}, expected={}".format(env.rail.grid, - expected_grid_map) - - s0 = 0 - s1 = 0 - for a in range(env.get_num_agents()): - s0 = Vec2d.get_manhattan_distance(env.agents[a].initial_position, (0, 0)) - s1 = Vec2d.get_chebyshev_distance(env.agents[a].initial_position, (0, 0)) - assert s0 == 58, "actual={}".format(s0) - assert s1 == 38, "actual={}".format(s1) diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py index 4600c4a3002995e1238a0ccbda762501ac985408..65d2d68c45155efda24536ecfd776bef5ebaab0c 100644 --- a/tests/test_flatland_envs_rail_env_shortest_paths.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -25,6 +25,7 @@ def test_get_shortest_paths_unreachable(): # set the initial position agent = env.agents_static[0] agent.position = (3, 1) # west dead-end + agent.initial_position = (3, 1) # west dead-end agent.direction = Grid4TransitionsEnum.WEST agent.target = (3, 9) # east dead-end agent.moving = True diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index 8b2cdbea9431d970778acb8d973cc0002d5a90f5..c4801958d7b0b1218c2169aed3cafee680fad5b4 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -11,709 +11,592 @@ from flatland.utils.rendertools import RenderTool def test_sparse_rail_generator(): + np.random.seed(0) + random.seed(0) env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map - num_intersections=10, # Number of interesections in map - num_trainstations=50, # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, # Proximity of stations to city center - num_neighb=3, # Number of connections to other cities - seed=5, # Random seed - grid_mode=False # Ordered distribution of nodes + rail_generator=sparse_rail_generator(max_num_cities=10, + max_rails_between_cities=3, + seed=5, + grid_mode=False ), schedule_generator=sparse_schedule_generator(), number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv()) + obs_builder_object=GlobalObsForRailEnv() + ) env.reset(False, False, True) + # for r in range(env.height): + # for c in range (env.width): + # if env.rail.grid[r][c] > 0: + # print("expected_grid_map[{}][{}] = {}".format(r,c,env.rail.grid[r][c])) + expected_grid_map = np.zeros((50, 50), dtype=env.rail.transitions.get_type()) - expected_grid_map[1][33] = 8192 - expected_grid_map[2][33] = 32800 - expected_grid_map[3][31] = 4 - expected_grid_map[3][32] = 4608 - expected_grid_map[3][33] = 32800 - expected_grid_map[4][30] = 16386 - expected_grid_map[4][31] = 17411 - expected_grid_map[4][32] = 1097 - expected_grid_map[4][33] = 38505 + expected_grid_map[4][32] = 16386 + expected_grid_map[4][33] = 1025 expected_grid_map[4][34] = 1025 - expected_grid_map[4][35] = 5633 + expected_grid_map[4][35] = 1025 expected_grid_map[4][36] = 1025 expected_grid_map[4][37] = 1025 - expected_grid_map[4][38] = 1025 + expected_grid_map[4][38] = 17411 expected_grid_map[4][39] = 1025 expected_grid_map[4][40] = 1025 expected_grid_map[4][41] = 1025 - expected_grid_map[4][42] = 1025 - expected_grid_map[4][43] = 1025 - expected_grid_map[4][44] = 1025 - expected_grid_map[4][45] = 1025 - expected_grid_map[4][46] = 1025 - expected_grid_map[4][47] = 1025 - expected_grid_map[4][48] = 4608 - expected_grid_map[5][30] = 128 - expected_grid_map[5][31] = 32800 - expected_grid_map[5][33] = 32800 - expected_grid_map[5][35] = 32800 - expected_grid_map[5][48] = 32800 - expected_grid_map[6][30] = 4 - expected_grid_map[6][31] = 2064 - expected_grid_map[6][33] = 32800 - expected_grid_map[6][35] = 128 - expected_grid_map[6][48] = 32800 - expected_grid_map[7][33] = 32872 - expected_grid_map[7][34] = 1025 - expected_grid_map[7][35] = 1025 - expected_grid_map[7][36] = 1025 - expected_grid_map[7][37] = 1025 - expected_grid_map[7][38] = 1025 - expected_grid_map[7][39] = 1025 - expected_grid_map[7][40] = 1025 - expected_grid_map[7][41] = 20994 - expected_grid_map[7][42] = 1025 - expected_grid_map[7][43] = 1025 - expected_grid_map[7][44] = 1025 - expected_grid_map[7][45] = 1025 - expected_grid_map[7][46] = 1025 - expected_grid_map[7][47] = 4608 - expected_grid_map[7][48] = 32800 - expected_grid_map[8][3] = 16386 - expected_grid_map[8][4] = 1025 - expected_grid_map[8][5] = 1025 - expected_grid_map[8][6] = 1025 - expected_grid_map[8][7] = 1025 - expected_grid_map[8][8] = 1025 - expected_grid_map[8][9] = 1025 - expected_grid_map[8][10] = 1025 - expected_grid_map[8][11] = 5633 - expected_grid_map[8][12] = 1025 - expected_grid_map[8][13] = 1025 - expected_grid_map[8][14] = 1025 - expected_grid_map[8][15] = 1025 - expected_grid_map[8][16] = 1025 - expected_grid_map[8][17] = 1025 - expected_grid_map[8][18] = 4608 - expected_grid_map[8][33] = 32800 + expected_grid_map[4][42] = 5633 + expected_grid_map[4][43] = 5633 + expected_grid_map[4][44] = 4608 + expected_grid_map[5][32] = 32800 + expected_grid_map[5][38] = 32800 + expected_grid_map[5][41] = 16386 + expected_grid_map[5][42] = 38505 + expected_grid_map[5][43] = 38505 + expected_grid_map[5][44] = 34864 + expected_grid_map[6][32] = 32800 + expected_grid_map[6][38] = 32800 + expected_grid_map[6][41] = 32800 + expected_grid_map[6][42] = 32800 + expected_grid_map[6][43] = 32800 + expected_grid_map[6][44] = 32800 + expected_grid_map[7][32] = 32800 + expected_grid_map[7][38] = 32800 + expected_grid_map[7][41] = 32800 + expected_grid_map[7][42] = 32800 + expected_grid_map[7][43] = 32800 + expected_grid_map[7][44] = 32800 + expected_grid_map[8][32] = 32800 + expected_grid_map[8][38] = 32800 expected_grid_map[8][41] = 32800 - expected_grid_map[8][47] = 32800 - expected_grid_map[8][48] = 32800 - expected_grid_map[9][3] = 32800 - expected_grid_map[9][11] = 32800 - expected_grid_map[9][12] = 8192 - expected_grid_map[9][13] = 8192 - expected_grid_map[9][18] = 32800 - expected_grid_map[9][33] = 32800 + expected_grid_map[8][42] = 32800 + expected_grid_map[8][43] = 32800 + expected_grid_map[8][44] = 32800 + expected_grid_map[9][18] = 16386 + expected_grid_map[9][19] = 1025 + expected_grid_map[9][20] = 1025 + expected_grid_map[9][21] = 1025 + expected_grid_map[9][22] = 1025 + expected_grid_map[9][23] = 5633 + expected_grid_map[9][24] = 1025 + expected_grid_map[9][25] = 1025 + expected_grid_map[9][26] = 1025 + expected_grid_map[9][27] = 1025 + expected_grid_map[9][28] = 1025 + expected_grid_map[9][29] = 1025 + expected_grid_map[9][30] = 1025 + expected_grid_map[9][31] = 5633 + expected_grid_map[9][32] = 3089 + expected_grid_map[9][33] = 1025 + expected_grid_map[9][34] = 1025 + expected_grid_map[9][35] = 1025 + expected_grid_map[9][36] = 1025 + expected_grid_map[9][37] = 1025 + expected_grid_map[9][38] = 2064 expected_grid_map[9][41] = 32800 - expected_grid_map[9][47] = 32800 - expected_grid_map[9][48] = 32800 - expected_grid_map[10][3] = 32800 - expected_grid_map[10][8] = 8192 - expected_grid_map[10][11] = 32800 - expected_grid_map[10][12] = 32800 - expected_grid_map[10][13] = 32800 - expected_grid_map[10][18] = 32800 - expected_grid_map[10][33] = 32800 + expected_grid_map[9][42] = 32800 + expected_grid_map[9][43] = 32800 + expected_grid_map[9][44] = 32800 + expected_grid_map[10][18] = 49186 + expected_grid_map[10][19] = 1025 + expected_grid_map[10][20] = 1025 + expected_grid_map[10][21] = 1025 + expected_grid_map[10][22] = 1025 + expected_grid_map[10][23] = 52275 + expected_grid_map[10][24] = 1025 + expected_grid_map[10][25] = 1025 + expected_grid_map[10][26] = 1025 + expected_grid_map[10][27] = 1025 + expected_grid_map[10][28] = 1025 + expected_grid_map[10][29] = 1025 + expected_grid_map[10][30] = 1025 + expected_grid_map[10][31] = 52275 + expected_grid_map[10][32] = 1025 + expected_grid_map[10][33] = 1025 + expected_grid_map[10][34] = 1025 + expected_grid_map[10][35] = 1025 + expected_grid_map[10][36] = 1025 + expected_grid_map[10][37] = 1025 + expected_grid_map[10][38] = 4608 expected_grid_map[10][41] = 32800 - expected_grid_map[10][47] = 32800 - expected_grid_map[10][48] = 32800 - expected_grid_map[11][3] = 32800 - expected_grid_map[11][8] = 32800 - expected_grid_map[11][11] = 32800 - expected_grid_map[11][12] = 32800 - expected_grid_map[11][13] = 32800 - expected_grid_map[11][18] = 32800 - expected_grid_map[11][33] = 32800 + expected_grid_map[10][42] = 32800 + expected_grid_map[10][43] = 32800 + expected_grid_map[10][44] = 32800 + expected_grid_map[11][18] = 49186 + expected_grid_map[11][19] = 1025 + expected_grid_map[11][20] = 1025 + expected_grid_map[11][21] = 1025 + expected_grid_map[11][22] = 17411 + expected_grid_map[11][23] = 1097 + expected_grid_map[11][24] = 1025 + expected_grid_map[11][25] = 1025 + expected_grid_map[11][26] = 1025 + expected_grid_map[11][27] = 1025 + expected_grid_map[11][28] = 1025 + expected_grid_map[11][29] = 1025 + expected_grid_map[11][30] = 1025 + expected_grid_map[11][31] = 3089 + expected_grid_map[11][32] = 1025 + expected_grid_map[11][33] = 1025 + expected_grid_map[11][34] = 1025 + expected_grid_map[11][35] = 1025 + expected_grid_map[11][36] = 1025 + expected_grid_map[11][37] = 1025 + expected_grid_map[11][38] = 37408 expected_grid_map[11][41] = 32800 - expected_grid_map[11][47] = 32800 - expected_grid_map[11][48] = 32800 - expected_grid_map[12][3] = 32800 - expected_grid_map[12][8] = 72 - expected_grid_map[12][9] = 1025 - expected_grid_map[12][10] = 17411 - expected_grid_map[12][11] = 52275 - expected_grid_map[12][12] = 3089 - expected_grid_map[12][13] = 3089 - expected_grid_map[12][14] = 1025 - expected_grid_map[12][15] = 1025 - expected_grid_map[12][16] = 1025 - expected_grid_map[12][17] = 1025 - expected_grid_map[12][18] = 33825 - expected_grid_map[12][19] = 1025 - expected_grid_map[12][20] = 1025 - expected_grid_map[12][21] = 1025 - expected_grid_map[12][22] = 1025 - expected_grid_map[12][23] = 1025 - expected_grid_map[12][24] = 1025 - expected_grid_map[12][25] = 1025 - expected_grid_map[12][26] = 1025 - expected_grid_map[12][27] = 1025 - expected_grid_map[12][28] = 1025 - expected_grid_map[12][29] = 1025 - expected_grid_map[12][30] = 1025 - expected_grid_map[12][31] = 1025 - expected_grid_map[12][32] = 1025 - expected_grid_map[12][33] = 33825 - expected_grid_map[12][34] = 1025 - expected_grid_map[12][35] = 1025 - expected_grid_map[12][36] = 1025 - expected_grid_map[12][37] = 1025 - expected_grid_map[12][38] = 1025 - expected_grid_map[12][39] = 1025 - expected_grid_map[12][40] = 1025 - expected_grid_map[12][41] = 35889 - expected_grid_map[12][42] = 4608 - expected_grid_map[12][47] = 32800 - expected_grid_map[12][48] = 32800 - expected_grid_map[13][3] = 32800 - expected_grid_map[13][10] = 32800 - expected_grid_map[13][11] = 32872 + expected_grid_map[11][42] = 32800 + expected_grid_map[11][43] = 32800 + expected_grid_map[11][44] = 32800 + expected_grid_map[12][18] = 32800 + expected_grid_map[12][22] = 32800 + expected_grid_map[12][38] = 32800 + expected_grid_map[12][41] = 32800 + expected_grid_map[12][42] = 32800 + expected_grid_map[12][43] = 32800 + expected_grid_map[12][44] = 32800 + expected_grid_map[13][6] = 16386 + expected_grid_map[13][7] = 1025 + expected_grid_map[13][8] = 1025 + expected_grid_map[13][9] = 17411 + expected_grid_map[13][10] = 1025 + expected_grid_map[13][11] = 1025 expected_grid_map[13][12] = 1025 - expected_grid_map[13][13] = 256 - expected_grid_map[13][15] = 8192 - expected_grid_map[13][16] = 8192 - expected_grid_map[13][17] = 8192 - expected_grid_map[13][18] = 32800 - expected_grid_map[13][20] = 8192 - expected_grid_map[13][33] = 32800 - expected_grid_map[13][41] = 32800 - expected_grid_map[13][42] = 32800 - expected_grid_map[13][47] = 32800 - expected_grid_map[13][48] = 32800 - expected_grid_map[14][3] = 32800 - expected_grid_map[14][10] = 128 - expected_grid_map[14][11] = 32800 - expected_grid_map[14][15] = 72 - expected_grid_map[14][16] = 37408 - expected_grid_map[14][17] = 32800 - expected_grid_map[14][18] = 32800 - expected_grid_map[14][20] = 32800 - expected_grid_map[14][33] = 32800 - expected_grid_map[14][41] = 32800 - expected_grid_map[14][42] = 32800 - expected_grid_map[14][47] = 32800 - expected_grid_map[14][48] = 32800 - expected_grid_map[15][3] = 32800 - expected_grid_map[15][11] = 32800 - expected_grid_map[15][15] = 4 - expected_grid_map[15][16] = 1097 - expected_grid_map[15][17] = 1097 - expected_grid_map[15][18] = 3089 + expected_grid_map[13][13] = 1025 + expected_grid_map[13][14] = 1025 + expected_grid_map[13][15] = 1025 + expected_grid_map[13][16] = 1025 + expected_grid_map[13][17] = 17411 + expected_grid_map[13][18] = 3089 + expected_grid_map[13][19] = 1025 + expected_grid_map[13][20] = 1025 + expected_grid_map[13][21] = 1025 + expected_grid_map[13][22] = 34864 + expected_grid_map[13][38] = 32800 + expected_grid_map[13][41] = 72 + expected_grid_map[13][42] = 38505 + expected_grid_map[13][43] = 52275 + expected_grid_map[13][44] = 34864 + expected_grid_map[14][6] = 49186 + expected_grid_map[14][7] = 1025 + expected_grid_map[14][8] = 1025 + expected_grid_map[14][9] = 38505 + expected_grid_map[14][10] = 1025 + expected_grid_map[14][11] = 1025 + expected_grid_map[14][12] = 1025 + expected_grid_map[14][13] = 1025 + expected_grid_map[14][14] = 1025 + expected_grid_map[14][15] = 1025 + expected_grid_map[14][16] = 1025 + expected_grid_map[14][17] = 52275 + expected_grid_map[14][18] = 1025 + expected_grid_map[14][19] = 1025 + expected_grid_map[14][20] = 1025 + expected_grid_map[14][21] = 1025 + expected_grid_map[14][22] = 34864 + expected_grid_map[14][38] = 72 + expected_grid_map[14][39] = 1025 + expected_grid_map[14][40] = 1025 + expected_grid_map[14][41] = 1025 + expected_grid_map[14][42] = 2136 + expected_grid_map[14][43] = 37408 + expected_grid_map[14][44] = 32800 + expected_grid_map[15][6] = 49186 + expected_grid_map[15][7] = 1025 + expected_grid_map[15][8] = 17411 + expected_grid_map[15][9] = 1097 + expected_grid_map[15][10] = 1025 + expected_grid_map[15][11] = 1025 + expected_grid_map[15][12] = 1025 + expected_grid_map[15][13] = 1025 + expected_grid_map[15][14] = 1025 + expected_grid_map[15][15] = 1025 + expected_grid_map[15][16] = 1025 + expected_grid_map[15][17] = 3089 + expected_grid_map[15][18] = 1025 expected_grid_map[15][19] = 1025 - expected_grid_map[15][20] = 3089 + expected_grid_map[15][20] = 1025 expected_grid_map[15][21] = 1025 - expected_grid_map[15][22] = 1025 - expected_grid_map[15][23] = 1025 - expected_grid_map[15][24] = 1025 - expected_grid_map[15][25] = 1025 - expected_grid_map[15][26] = 1025 - expected_grid_map[15][27] = 1025 - expected_grid_map[15][28] = 1025 - expected_grid_map[15][29] = 1025 - expected_grid_map[15][30] = 1025 - expected_grid_map[15][31] = 1025 - expected_grid_map[15][32] = 1025 - expected_grid_map[15][33] = 33825 - expected_grid_map[15][34] = 1025 - expected_grid_map[15][35] = 1025 - expected_grid_map[15][36] = 1025 - expected_grid_map[15][37] = 1025 - expected_grid_map[15][38] = 1025 - expected_grid_map[15][39] = 1025 - expected_grid_map[15][40] = 1025 - expected_grid_map[15][41] = 35889 - expected_grid_map[15][42] = 37408 - expected_grid_map[15][47] = 32800 - expected_grid_map[15][48] = 32800 - expected_grid_map[16][3] = 32800 - expected_grid_map[16][7] = 8192 - expected_grid_map[16][11] = 32800 - expected_grid_map[16][33] = 32800 - expected_grid_map[16][41] = 32800 - expected_grid_map[16][42] = 32800 - expected_grid_map[16][47] = 32800 - expected_grid_map[16][48] = 32800 - expected_grid_map[17][3] = 32800 - expected_grid_map[17][7] = 32800 - expected_grid_map[17][9] = 8192 - expected_grid_map[17][10] = 8192 - expected_grid_map[17][11] = 32800 - expected_grid_map[17][33] = 32800 - expected_grid_map[17][41] = 32800 - expected_grid_map[17][42] = 32800 - expected_grid_map[17][47] = 32800 - expected_grid_map[17][48] = 32800 - expected_grid_map[18][3] = 32800 - expected_grid_map[18][7] = 32800 - expected_grid_map[18][8] = 8192 - expected_grid_map[18][9] = 32800 - expected_grid_map[18][10] = 32800 - expected_grid_map[18][11] = 32800 - expected_grid_map[18][33] = 32800 - expected_grid_map[18][41] = 32800 - expected_grid_map[18][42] = 32800 - expected_grid_map[18][47] = 32800 - expected_grid_map[18][48] = 32800 - expected_grid_map[19][3] = 72 - expected_grid_map[19][4] = 1025 - expected_grid_map[19][5] = 1025 - expected_grid_map[19][6] = 1025 - expected_grid_map[19][7] = 1097 - expected_grid_map[19][8] = 1097 - expected_grid_map[19][9] = 1097 - expected_grid_map[19][10] = 52275 - expected_grid_map[19][11] = 33825 - expected_grid_map[19][12] = 1025 - expected_grid_map[19][13] = 1025 - expected_grid_map[19][14] = 4608 - expected_grid_map[19][33] = 32800 - expected_grid_map[19][41] = 32800 - expected_grid_map[19][42] = 32800 - expected_grid_map[19][47] = 32800 - expected_grid_map[19][48] = 32800 - expected_grid_map[20][7] = 4 - expected_grid_map[20][8] = 1025 - expected_grid_map[20][9] = 1025 - expected_grid_map[20][10] = 34864 - expected_grid_map[20][11] = 32800 - expected_grid_map[20][14] = 32800 - expected_grid_map[20][33] = 32800 - expected_grid_map[20][41] = 32800 + expected_grid_map[15][22] = 38505 + expected_grid_map[15][23] = 4608 + expected_grid_map[15][43] = 32800 + expected_grid_map[15][44] = 32800 + expected_grid_map[16][6] = 32800 + expected_grid_map[16][8] = 32800 + expected_grid_map[16][22] = 32800 + expected_grid_map[16][23] = 32800 + expected_grid_map[16][43] = 32800 + expected_grid_map[16][44] = 32800 + expected_grid_map[17][6] = 32800 + expected_grid_map[17][8] = 32800 + expected_grid_map[17][22] = 32800 + expected_grid_map[17][23] = 32800 + expected_grid_map[17][43] = 32800 + expected_grid_map[17][44] = 32800 + expected_grid_map[18][6] = 32800 + expected_grid_map[18][8] = 32800 + expected_grid_map[18][22] = 32800 + expected_grid_map[18][23] = 32800 + expected_grid_map[18][43] = 32800 + expected_grid_map[18][44] = 32872 + expected_grid_map[18][45] = 4608 + expected_grid_map[19][6] = 32800 + expected_grid_map[19][8] = 32800 + expected_grid_map[19][22] = 32800 + expected_grid_map[19][23] = 32800 + expected_grid_map[19][42] = 16386 + expected_grid_map[19][43] = 52275 + expected_grid_map[19][44] = 38505 + expected_grid_map[19][45] = 34864 + expected_grid_map[20][4] = 16386 + expected_grid_map[20][5] = 17411 + expected_grid_map[20][6] = 50211 + expected_grid_map[20][7] = 1025 + expected_grid_map[20][8] = 2064 + expected_grid_map[20][22] = 32800 + expected_grid_map[20][23] = 32800 expected_grid_map[20][42] = 32800 - expected_grid_map[20][47] = 32800 - expected_grid_map[20][48] = 32800 - expected_grid_map[21][10] = 32800 - expected_grid_map[21][11] = 32800 - expected_grid_map[21][14] = 32800 - expected_grid_map[21][24] = 8192 - expected_grid_map[21][33] = 32872 - expected_grid_map[21][34] = 1025 - expected_grid_map[21][35] = 1025 - expected_grid_map[21][36] = 1025 - expected_grid_map[21][37] = 1025 - expected_grid_map[21][38] = 1025 - expected_grid_map[21][39] = 1025 - expected_grid_map[21][40] = 1025 - expected_grid_map[21][41] = 33825 - expected_grid_map[21][42] = 38505 - expected_grid_map[21][43] = 1025 - expected_grid_map[21][44] = 1025 - expected_grid_map[21][45] = 1025 - expected_grid_map[21][46] = 1025 - expected_grid_map[21][47] = 37408 - expected_grid_map[21][48] = 32800 - expected_grid_map[22][10] = 32800 - expected_grid_map[22][11] = 32800 - expected_grid_map[22][14] = 32800 - expected_grid_map[22][22] = 8192 - expected_grid_map[22][24] = 32800 - expected_grid_map[22][27] = 8192 - expected_grid_map[22][33] = 32800 - expected_grid_map[22][41] = 32800 + expected_grid_map[20][43] = 32800 + expected_grid_map[20][44] = 32800 + expected_grid_map[20][45] = 32800 + expected_grid_map[21][3] = 16386 + expected_grid_map[21][4] = 38505 + expected_grid_map[21][5] = 52275 + expected_grid_map[21][6] = 37408 + expected_grid_map[21][22] = 32800 + expected_grid_map[21][23] = 32800 + expected_grid_map[21][42] = 32800 + expected_grid_map[21][43] = 32800 + expected_grid_map[21][44] = 32800 + expected_grid_map[21][45] = 32800 + expected_grid_map[22][3] = 32800 + expected_grid_map[22][4] = 32800 + expected_grid_map[22][5] = 32800 + expected_grid_map[22][6] = 32800 + expected_grid_map[22][21] = 16386 + expected_grid_map[22][22] = 34864 + expected_grid_map[22][23] = 32872 + expected_grid_map[22][24] = 4608 expected_grid_map[22][42] = 32800 - expected_grid_map[22][47] = 32800 - expected_grid_map[22][48] = 32800 - expected_grid_map[23][10] = 32800 - expected_grid_map[23][11] = 32800 - expected_grid_map[23][14] = 32800 - expected_grid_map[23][22] = 72 - expected_grid_map[23][23] = 17411 - expected_grid_map[23][24] = 1097 - expected_grid_map[23][25] = 17411 - expected_grid_map[23][26] = 1025 - expected_grid_map[23][27] = 3089 - expected_grid_map[23][28] = 1025 - expected_grid_map[23][29] = 1025 - expected_grid_map[23][30] = 1025 - expected_grid_map[23][31] = 1025 - expected_grid_map[23][32] = 1025 - expected_grid_map[23][33] = 33825 - expected_grid_map[23][34] = 1025 - expected_grid_map[23][35] = 1025 - expected_grid_map[23][36] = 1025 - expected_grid_map[23][37] = 1025 - expected_grid_map[23][38] = 1025 - expected_grid_map[23][39] = 1025 - expected_grid_map[23][40] = 1025 - expected_grid_map[23][41] = 3089 - expected_grid_map[23][42] = 34864 - expected_grid_map[23][47] = 32800 - expected_grid_map[23][48] = 32800 - expected_grid_map[24][10] = 32800 - expected_grid_map[24][11] = 32800 - expected_grid_map[24][14] = 32800 + expected_grid_map[22][43] = 32800 + expected_grid_map[22][44] = 32800 + expected_grid_map[22][45] = 32800 + expected_grid_map[23][3] = 32800 + expected_grid_map[23][4] = 32800 + expected_grid_map[23][5] = 32800 + expected_grid_map[23][6] = 32800 + expected_grid_map[23][21] = 32872 + expected_grid_map[23][22] = 38505 + expected_grid_map[23][23] = 38505 + expected_grid_map[23][24] = 37408 + expected_grid_map[23][42] = 32800 + expected_grid_map[23][43] = 32800 + expected_grid_map[23][44] = 32800 + expected_grid_map[23][45] = 32800 + expected_grid_map[24][3] = 32800 + expected_grid_map[24][4] = 32800 + expected_grid_map[24][5] = 32800 + expected_grid_map[24][6] = 32800 + expected_grid_map[24][21] = 32800 + expected_grid_map[24][22] = 32800 expected_grid_map[24][23] = 32800 - expected_grid_map[24][24] = 4 - expected_grid_map[24][25] = 34864 - expected_grid_map[24][33] = 32800 + expected_grid_map[24][24] = 32800 expected_grid_map[24][42] = 32800 - expected_grid_map[24][47] = 32800 - expected_grid_map[24][48] = 32800 - expected_grid_map[25][10] = 32800 - expected_grid_map[25][11] = 32800 - expected_grid_map[25][14] = 32800 - expected_grid_map[25][23] = 128 - expected_grid_map[25][25] = 32800 - expected_grid_map[25][33] = 32800 + expected_grid_map[24][43] = 32800 + expected_grid_map[24][44] = 32800 + expected_grid_map[24][45] = 32800 + expected_grid_map[25][3] = 32800 + expected_grid_map[25][4] = 32800 + expected_grid_map[25][5] = 32800 + expected_grid_map[25][6] = 32800 + expected_grid_map[25][21] = 32800 + expected_grid_map[25][22] = 32800 + expected_grid_map[25][23] = 32800 + expected_grid_map[25][24] = 32800 expected_grid_map[25][42] = 32800 - expected_grid_map[25][47] = 32800 - expected_grid_map[25][48] = 32800 - expected_grid_map[26][10] = 32800 - expected_grid_map[26][11] = 32800 - expected_grid_map[26][14] = 32800 - expected_grid_map[26][25] = 32800 - expected_grid_map[26][33] = 32800 + expected_grid_map[25][43] = 32800 + expected_grid_map[25][44] = 32800 + expected_grid_map[25][45] = 32800 + expected_grid_map[26][3] = 32800 + expected_grid_map[26][4] = 32800 + expected_grid_map[26][5] = 32800 + expected_grid_map[26][6] = 32800 + expected_grid_map[26][21] = 32800 + expected_grid_map[26][22] = 32800 + expected_grid_map[26][23] = 32800 + expected_grid_map[26][24] = 32800 expected_grid_map[26][42] = 32800 - expected_grid_map[26][47] = 32800 - expected_grid_map[26][48] = 32800 - expected_grid_map[27][10] = 32800 - expected_grid_map[27][11] = 32800 - expected_grid_map[27][14] = 32800 - expected_grid_map[27][25] = 32800 - expected_grid_map[27][33] = 32800 - expected_grid_map[27][42] = 32800 - expected_grid_map[27][47] = 32800 - expected_grid_map[27][48] = 32800 - expected_grid_map[28][10] = 32800 - expected_grid_map[28][11] = 32800 - expected_grid_map[28][14] = 32800 - expected_grid_map[28][25] = 32800 - expected_grid_map[28][33] = 49186 - expected_grid_map[28][34] = 256 - expected_grid_map[28][42] = 32800 - expected_grid_map[28][44] = 8192 - expected_grid_map[28][45] = 8192 - expected_grid_map[28][47] = 32800 - expected_grid_map[28][48] = 32800 - expected_grid_map[28][49] = 8192 - expected_grid_map[29][10] = 32800 - expected_grid_map[29][11] = 32800 - expected_grid_map[29][14] = 32800 - expected_grid_map[29][25] = 32800 - expected_grid_map[29][32] = 16386 - expected_grid_map[29][33] = 37408 - expected_grid_map[29][34] = 8192 - expected_grid_map[29][42] = 32800 - expected_grid_map[29][44] = 72 - expected_grid_map[29][45] = 37408 - expected_grid_map[29][47] = 32800 + expected_grid_map[26][43] = 32800 + expected_grid_map[26][44] = 32800 + expected_grid_map[26][45] = 32800 + expected_grid_map[27][3] = 32800 + expected_grid_map[27][4] = 32800 + expected_grid_map[27][5] = 32800 + expected_grid_map[27][6] = 32800 + expected_grid_map[27][21] = 32800 + expected_grid_map[27][22] = 32800 + expected_grid_map[27][23] = 32800 + expected_grid_map[27][24] = 32800 + expected_grid_map[27][42] = 72 + expected_grid_map[27][43] = 52275 + expected_grid_map[27][44] = 38505 + expected_grid_map[27][45] = 34864 + expected_grid_map[28][3] = 32800 + expected_grid_map[28][4] = 32800 + expected_grid_map[28][5] = 32800 + expected_grid_map[28][6] = 32800 + expected_grid_map[28][21] = 32800 + expected_grid_map[28][22] = 32800 + expected_grid_map[28][23] = 32800 + expected_grid_map[28][24] = 32800 + expected_grid_map[28][40] = 16386 + expected_grid_map[28][41] = 1025 + expected_grid_map[28][42] = 1025 + expected_grid_map[28][43] = 2064 + expected_grid_map[28][44] = 72 + expected_grid_map[28][45] = 33897 + expected_grid_map[28][46] = 1025 + expected_grid_map[28][47] = 1025 + expected_grid_map[28][48] = 4608 + expected_grid_map[29][3] = 72 + expected_grid_map[29][4] = 52275 + expected_grid_map[29][5] = 52275 + expected_grid_map[29][6] = 37408 + expected_grid_map[29][21] = 32800 + expected_grid_map[29][22] = 32800 + expected_grid_map[29][23] = 32800 + expected_grid_map[29][24] = 32800 + expected_grid_map[29][40] = 32800 + expected_grid_map[29][45] = 32800 expected_grid_map[29][48] = 32800 - expected_grid_map[29][49] = 32800 - expected_grid_map[30][10] = 32800 - expected_grid_map[30][11] = 32800 - expected_grid_map[30][14] = 32800 - expected_grid_map[30][25] = 32800 - expected_grid_map[30][32] = 128 - expected_grid_map[30][33] = 49186 - expected_grid_map[30][34] = 33825 - expected_grid_map[30][35] = 1025 - expected_grid_map[30][36] = 1025 - expected_grid_map[30][37] = 1025 - expected_grid_map[30][38] = 1025 - expected_grid_map[30][39] = 5633 - expected_grid_map[30][40] = 1025 - expected_grid_map[30][41] = 1025 - expected_grid_map[30][42] = 2064 - expected_grid_map[30][45] = 16458 - expected_grid_map[30][46] = 17411 - expected_grid_map[30][47] = 38505 - expected_grid_map[30][48] = 38433 - expected_grid_map[30][49] = 2064 - expected_grid_map[31][10] = 32800 - expected_grid_map[31][11] = 32800 - expected_grid_map[31][14] = 32800 - expected_grid_map[31][25] = 32800 - expected_grid_map[31][30] = 8192 - expected_grid_map[31][31] = 4 - expected_grid_map[31][32] = 17411 - expected_grid_map[31][33] = 34864 - expected_grid_map[31][34] = 32800 - expected_grid_map[31][39] = 32800 + expected_grid_map[30][4] = 72 + expected_grid_map[30][5] = 1097 + expected_grid_map[30][6] = 37408 + expected_grid_map[30][21] = 32800 + expected_grid_map[30][22] = 32800 + expected_grid_map[30][23] = 32800 + expected_grid_map[30][24] = 32800 + expected_grid_map[30][40] = 32800 + expected_grid_map[30][45] = 32800 + expected_grid_map[30][48] = 32800 + expected_grid_map[31][6] = 32872 + expected_grid_map[31][7] = 5633 + expected_grid_map[31][8] = 4608 + expected_grid_map[31][21] = 49186 + expected_grid_map[31][22] = 52275 + expected_grid_map[31][23] = 38505 + expected_grid_map[31][24] = 34864 + expected_grid_map[31][40] = 32800 expected_grid_map[31][45] = 32800 - expected_grid_map[31][46] = 32800 - expected_grid_map[31][47] = 32800 expected_grid_map[31][48] = 32800 - expected_grid_map[32][10] = 32800 - expected_grid_map[32][11] = 32800 - expected_grid_map[32][14] = 32800 - expected_grid_map[32][25] = 32800 - expected_grid_map[32][30] = 72 - expected_grid_map[32][31] = 1025 - expected_grid_map[32][32] = 2064 - expected_grid_map[32][33] = 32872 - expected_grid_map[32][34] = 2064 - expected_grid_map[32][39] = 32800 - expected_grid_map[32][45] = 128 - expected_grid_map[32][46] = 128 - expected_grid_map[32][47] = 32800 + expected_grid_map[32][6] = 32872 + expected_grid_map[32][7] = 52275 + expected_grid_map[32][8] = 37408 + expected_grid_map[32][21] = 72 + expected_grid_map[32][22] = 1097 + expected_grid_map[32][23] = 37408 + expected_grid_map[32][24] = 32800 + expected_grid_map[32][40] = 32800 + expected_grid_map[32][45] = 32800 expected_grid_map[32][48] = 32800 - expected_grid_map[33][10] = 32800 - expected_grid_map[33][11] = 32800 - expected_grid_map[33][14] = 32872 - expected_grid_map[33][15] = 1025 - expected_grid_map[33][16] = 1025 - expected_grid_map[33][17] = 1025 - expected_grid_map[33][18] = 1025 - expected_grid_map[33][19] = 1025 - expected_grid_map[33][20] = 1025 - expected_grid_map[33][21] = 1025 - expected_grid_map[33][22] = 1025 - expected_grid_map[33][23] = 1025 - expected_grid_map[33][24] = 1025 - expected_grid_map[33][25] = 35889 - expected_grid_map[33][26] = 1025 - expected_grid_map[33][27] = 1025 - expected_grid_map[33][28] = 1025 - expected_grid_map[33][29] = 1025 - expected_grid_map[33][30] = 1025 - expected_grid_map[33][31] = 1025 - expected_grid_map[33][32] = 1025 - expected_grid_map[33][33] = 34864 - expected_grid_map[33][39] = 32800 - expected_grid_map[33][47] = 32800 + expected_grid_map[33][6] = 32800 + expected_grid_map[33][7] = 32800 + expected_grid_map[33][8] = 32800 + expected_grid_map[33][23] = 32800 + expected_grid_map[33][24] = 32800 + expected_grid_map[33][40] = 32800 + expected_grid_map[33][45] = 32800 expected_grid_map[33][48] = 32800 - expected_grid_map[34][5] = 16386 - expected_grid_map[34][6] = 1025 - expected_grid_map[34][7] = 1025 - expected_grid_map[34][8] = 1025 - expected_grid_map[34][9] = 1025 - expected_grid_map[34][10] = 33825 - expected_grid_map[34][11] = 3089 - expected_grid_map[34][12] = 1025 - expected_grid_map[34][13] = 1025 - expected_grid_map[34][14] = 33825 - expected_grid_map[34][15] = 1025 - expected_grid_map[34][16] = 1025 - expected_grid_map[34][17] = 1025 - expected_grid_map[34][18] = 1025 - expected_grid_map[34][19] = 1025 - expected_grid_map[34][20] = 1025 - expected_grid_map[34][21] = 1025 - expected_grid_map[34][22] = 1025 - expected_grid_map[34][23] = 1025 - expected_grid_map[34][24] = 1025 - expected_grid_map[34][25] = 2064 - expected_grid_map[34][33] = 32800 - expected_grid_map[34][39] = 32800 - expected_grid_map[34][47] = 32800 + expected_grid_map[34][6] = 32800 + expected_grid_map[34][7] = 32800 + expected_grid_map[34][8] = 32800 + expected_grid_map[34][23] = 32800 + expected_grid_map[34][24] = 32800 + expected_grid_map[34][40] = 32800 + expected_grid_map[34][45] = 32800 expected_grid_map[34][48] = 32800 - expected_grid_map[35][5] = 32800 - expected_grid_map[35][10] = 32800 - expected_grid_map[35][14] = 32800 - expected_grid_map[35][16] = 8192 - expected_grid_map[35][33] = 32800 - expected_grid_map[35][39] = 32800 - expected_grid_map[35][47] = 32800 + expected_grid_map[35][6] = 32800 + expected_grid_map[35][7] = 32800 + expected_grid_map[35][8] = 32800 + expected_grid_map[35][23] = 72 + expected_grid_map[35][24] = 37408 + expected_grid_map[35][40] = 32800 + expected_grid_map[35][45] = 32800 expected_grid_map[35][48] = 32800 - expected_grid_map[36][5] = 32800 - expected_grid_map[36][10] = 32800 - expected_grid_map[36][14] = 32800 - expected_grid_map[36][16] = 32800 - expected_grid_map[36][17] = 8192 - expected_grid_map[36][19] = 8192 - expected_grid_map[36][33] = 32800 - expected_grid_map[36][39] = 32800 - expected_grid_map[36][41] = 8192 - expected_grid_map[36][47] = 32800 - expected_grid_map[36][48] = 32800 - expected_grid_map[37][5] = 32800 - expected_grid_map[37][10] = 32800 - expected_grid_map[37][14] = 32800 - expected_grid_map[37][16] = 32800 - expected_grid_map[37][17] = 49186 - expected_grid_map[37][18] = 1025 - expected_grid_map[37][19] = 2064 - expected_grid_map[37][33] = 32800 + expected_grid_map[36][6] = 32800 + expected_grid_map[36][7] = 32800 + expected_grid_map[36][8] = 32800 + expected_grid_map[36][24] = 32800 + expected_grid_map[36][39] = 16386 + expected_grid_map[36][40] = 2064 + expected_grid_map[36][45] = 72 + expected_grid_map[36][46] = 1025 + expected_grid_map[36][47] = 1025 + expected_grid_map[36][48] = 1097 + expected_grid_map[36][49] = 4608 + expected_grid_map[37][6] = 32800 + expected_grid_map[37][7] = 32800 + expected_grid_map[37][8] = 32800 + expected_grid_map[37][24] = 32800 expected_grid_map[37][39] = 32800 - expected_grid_map[37][41] = 32800 - expected_grid_map[37][42] = 16386 - expected_grid_map[37][43] = 256 - expected_grid_map[37][47] = 32800 - expected_grid_map[37][48] = 32800 - expected_grid_map[38][5] = 72 - expected_grid_map[38][6] = 1025 - expected_grid_map[38][7] = 1025 - expected_grid_map[38][8] = 1025 - expected_grid_map[38][9] = 1025 - expected_grid_map[38][10] = 33825 - expected_grid_map[38][11] = 1025 - expected_grid_map[38][12] = 1025 + expected_grid_map[37][49] = 32800 + expected_grid_map[38][6] = 32800 + expected_grid_map[38][7] = 32800 + expected_grid_map[38][8] = 32800 + expected_grid_map[38][12] = 16386 expected_grid_map[38][13] = 1025 - expected_grid_map[38][14] = 33897 - expected_grid_map[38][15] = 17411 - expected_grid_map[38][16] = 1097 - expected_grid_map[38][17] = 38505 - expected_grid_map[38][18] = 256 - expected_grid_map[38][33] = 32800 + expected_grid_map[38][14] = 1025 + expected_grid_map[38][15] = 5633 + expected_grid_map[38][16] = 1025 + expected_grid_map[38][17] = 1025 + expected_grid_map[38][18] = 1025 + expected_grid_map[38][19] = 1025 + expected_grid_map[38][20] = 1025 + expected_grid_map[38][21] = 1025 + expected_grid_map[38][22] = 1025 + expected_grid_map[38][23] = 17411 + expected_grid_map[38][24] = 3089 + expected_grid_map[38][25] = 1025 + expected_grid_map[38][26] = 1025 + expected_grid_map[38][27] = 4608 expected_grid_map[38][39] = 32800 - expected_grid_map[38][41] = 32800 - expected_grid_map[38][42] = 32800 - expected_grid_map[38][43] = 8192 - expected_grid_map[38][47] = 32800 - expected_grid_map[38][48] = 32800 - expected_grid_map[39][10] = 32800 - expected_grid_map[39][14] = 32800 - expected_grid_map[39][15] = 32800 - expected_grid_map[39][17] = 32800 - expected_grid_map[39][18] = 4 - expected_grid_map[39][19] = 4608 - expected_grid_map[39][33] = 32800 - expected_grid_map[39][39] = 49186 - expected_grid_map[39][40] = 17411 - expected_grid_map[39][41] = 1097 - expected_grid_map[39][42] = 52275 - expected_grid_map[39][43] = 3089 - expected_grid_map[39][44] = 4608 - expected_grid_map[39][47] = 32800 - expected_grid_map[39][48] = 32800 - expected_grid_map[40][10] = 32800 - expected_grid_map[40][14] = 32800 - expected_grid_map[40][15] = 128 - expected_grid_map[40][17] = 32800 - expected_grid_map[40][18] = 8192 - expected_grid_map[40][19] = 32800 - expected_grid_map[40][33] = 32800 - expected_grid_map[40][39] = 32800 - expected_grid_map[40][40] = 32800 - expected_grid_map[40][42] = 32872 - expected_grid_map[40][43] = 4608 - expected_grid_map[40][44] = 32800 - expected_grid_map[40][47] = 32800 - expected_grid_map[40][48] = 32800 - expected_grid_map[41][10] = 32800 - expected_grid_map[41][14] = 32800 - expected_grid_map[41][17] = 32800 - expected_grid_map[41][18] = 72 - expected_grid_map[41][19] = 37408 - expected_grid_map[41][21] = 8192 - expected_grid_map[41][33] = 32800 - expected_grid_map[41][39] = 32800 - expected_grid_map[41][40] = 128 - expected_grid_map[41][42] = 32800 - expected_grid_map[41][43] = 128 - expected_grid_map[41][44] = 32800 - expected_grid_map[41][47] = 32800 - expected_grid_map[41][48] = 32800 - expected_grid_map[42][10] = 32800 - expected_grid_map[42][14] = 72 - expected_grid_map[42][15] = 1025 - expected_grid_map[42][16] = 1025 - expected_grid_map[42][17] = 33825 - expected_grid_map[42][18] = 17411 - expected_grid_map[42][19] = 52275 - expected_grid_map[42][20] = 5633 - expected_grid_map[42][21] = 3089 - expected_grid_map[42][22] = 1025 - expected_grid_map[42][23] = 1025 - expected_grid_map[42][24] = 1025 + expected_grid_map[38][49] = 32800 + expected_grid_map[39][6] = 32800 + expected_grid_map[39][7] = 32800 + expected_grid_map[39][8] = 32800 + expected_grid_map[39][12] = 49186 + expected_grid_map[39][13] = 1025 + expected_grid_map[39][14] = 1025 + expected_grid_map[39][15] = 52275 + expected_grid_map[39][16] = 1025 + expected_grid_map[39][17] = 1025 + expected_grid_map[39][18] = 1025 + expected_grid_map[39][19] = 1025 + expected_grid_map[39][20] = 1025 + expected_grid_map[39][21] = 1025 + expected_grid_map[39][22] = 1025 + expected_grid_map[39][23] = 52275 + expected_grid_map[39][24] = 1025 + expected_grid_map[39][25] = 1025 + expected_grid_map[39][26] = 1025 + expected_grid_map[39][27] = 37408 + expected_grid_map[39][39] = 32800 + expected_grid_map[39][49] = 32800 + expected_grid_map[40][6] = 32872 + expected_grid_map[40][7] = 38505 + expected_grid_map[40][8] = 37408 + expected_grid_map[40][12] = 49186 + expected_grid_map[40][13] = 1025 + expected_grid_map[40][14] = 17411 + expected_grid_map[40][15] = 52275 + expected_grid_map[40][16] = 1025 + expected_grid_map[40][17] = 1025 + expected_grid_map[40][18] = 1025 + expected_grid_map[40][19] = 1025 + expected_grid_map[40][20] = 1025 + expected_grid_map[40][21] = 1025 + expected_grid_map[40][22] = 1025 + expected_grid_map[40][23] = 38505 + expected_grid_map[40][24] = 1025 + expected_grid_map[40][25] = 1025 + expected_grid_map[40][26] = 1025 + expected_grid_map[40][27] = 1097 + expected_grid_map[40][28] = 17411 + expected_grid_map[40][29] = 1025 + expected_grid_map[40][30] = 1025 + expected_grid_map[40][31] = 1025 + expected_grid_map[40][32] = 1025 + expected_grid_map[40][33] = 1025 + expected_grid_map[40][34] = 1025 + expected_grid_map[40][35] = 1025 + expected_grid_map[40][36] = 17411 + expected_grid_map[40][37] = 1025 + expected_grid_map[40][38] = 1025 + expected_grid_map[40][39] = 1097 + expected_grid_map[40][40] = 5633 + expected_grid_map[40][41] = 1025 + expected_grid_map[40][42] = 1025 + expected_grid_map[40][43] = 1025 + expected_grid_map[40][44] = 1025 + expected_grid_map[40][45] = 1025 + expected_grid_map[40][46] = 1025 + expected_grid_map[40][47] = 1025 + expected_grid_map[40][48] = 5633 + expected_grid_map[40][49] = 34864 + expected_grid_map[41][6] = 72 + expected_grid_map[41][7] = 1097 + expected_grid_map[41][8] = 1097 + expected_grid_map[41][9] = 1025 + expected_grid_map[41][10] = 1025 + expected_grid_map[41][11] = 1025 + expected_grid_map[41][12] = 3089 + expected_grid_map[41][13] = 1025 + expected_grid_map[41][14] = 3089 + expected_grid_map[41][15] = 3089 + expected_grid_map[41][16] = 1025 + expected_grid_map[41][17] = 1025 + expected_grid_map[41][18] = 1025 + expected_grid_map[41][19] = 1025 + expected_grid_map[41][20] = 1025 + expected_grid_map[41][21] = 1025 + expected_grid_map[41][22] = 1025 + expected_grid_map[41][23] = 3089 + expected_grid_map[41][24] = 5633 + expected_grid_map[41][25] = 1025 + expected_grid_map[41][26] = 1025 + expected_grid_map[41][27] = 1025 + expected_grid_map[41][28] = 38505 + expected_grid_map[41][29] = 1025 + expected_grid_map[41][30] = 1025 + expected_grid_map[41][31] = 1025 + expected_grid_map[41][32] = 1025 + expected_grid_map[41][33] = 1025 + expected_grid_map[41][34] = 1025 + expected_grid_map[41][35] = 1025 + expected_grid_map[41][36] = 52275 + expected_grid_map[41][37] = 1025 + expected_grid_map[41][38] = 1025 + expected_grid_map[41][39] = 1025 + expected_grid_map[41][40] = 38505 + expected_grid_map[41][41] = 1025 + expected_grid_map[41][42] = 1025 + expected_grid_map[41][43] = 1025 + expected_grid_map[41][44] = 1025 + expected_grid_map[41][45] = 1025 + expected_grid_map[41][46] = 1025 + expected_grid_map[41][47] = 1025 + expected_grid_map[41][48] = 52275 + expected_grid_map[41][49] = 34864 + expected_grid_map[42][24] = 72 expected_grid_map[42][25] = 1025 expected_grid_map[42][26] = 1025 expected_grid_map[42][27] = 1025 - expected_grid_map[42][28] = 1025 + expected_grid_map[42][28] = 1097 expected_grid_map[42][29] = 1025 - expected_grid_map[42][30] = 4608 - expected_grid_map[42][33] = 32800 - expected_grid_map[42][39] = 32800 - expected_grid_map[42][42] = 32800 - expected_grid_map[42][44] = 32800 - expected_grid_map[42][47] = 32800 - expected_grid_map[42][48] = 32800 - expected_grid_map[43][10] = 32800 - expected_grid_map[43][17] = 32800 - expected_grid_map[43][18] = 128 - expected_grid_map[43][19] = 32800 - expected_grid_map[43][20] = 32800 - expected_grid_map[43][30] = 32800 - expected_grid_map[43][33] = 32800 - expected_grid_map[43][39] = 32800 - expected_grid_map[43][42] = 32800 - expected_grid_map[43][44] = 32800 - expected_grid_map[43][47] = 32800 - expected_grid_map[43][48] = 32800 - expected_grid_map[44][4] = 4 - expected_grid_map[44][5] = 1025 - expected_grid_map[44][6] = 1025 - expected_grid_map[44][7] = 1025 - expected_grid_map[44][8] = 1025 - expected_grid_map[44][9] = 1025 - expected_grid_map[44][10] = 3089 - expected_grid_map[44][11] = 1025 - expected_grid_map[44][12] = 1025 - expected_grid_map[44][13] = 1025 - expected_grid_map[44][14] = 1025 - expected_grid_map[44][15] = 1025 - expected_grid_map[44][16] = 1025 - expected_grid_map[44][17] = 3089 - expected_grid_map[44][18] = 1025 - expected_grid_map[44][19] = 2064 - expected_grid_map[44][20] = 128 - expected_grid_map[44][30] = 72 - expected_grid_map[44][31] = 1025 - expected_grid_map[44][32] = 1025 - expected_grid_map[44][33] = 35889 - expected_grid_map[44][34] = 1025 - expected_grid_map[44][35] = 1025 - expected_grid_map[44][36] = 1025 - expected_grid_map[44][37] = 1025 - expected_grid_map[44][38] = 1025 - expected_grid_map[44][39] = 33825 - expected_grid_map[44][40] = 1025 - expected_grid_map[44][41] = 1025 - expected_grid_map[44][42] = 2064 - expected_grid_map[44][44] = 32800 - expected_grid_map[44][47] = 32800 - expected_grid_map[44][48] = 32800 - expected_grid_map[45][33] = 32872 - expected_grid_map[45][34] = 1025 - expected_grid_map[45][35] = 1025 - expected_grid_map[45][36] = 1025 - expected_grid_map[45][37] = 1025 - expected_grid_map[45][38] = 1025 - expected_grid_map[45][39] = 33825 - expected_grid_map[45][40] = 1025 - expected_grid_map[45][41] = 1025 - expected_grid_map[45][42] = 1025 - expected_grid_map[45][43] = 1025 - expected_grid_map[45][44] = 1097 - expected_grid_map[45][45] = 1025 - expected_grid_map[45][46] = 1025 - expected_grid_map[45][47] = 34864 - expected_grid_map[45][48] = 32800 - expected_grid_map[46][33] = 32800 - expected_grid_map[46][39] = 32800 - expected_grid_map[46][47] = 32800 - expected_grid_map[46][48] = 32800 - expected_grid_map[47][33] = 32800 - expected_grid_map[47][39] = 32800 - expected_grid_map[47][47] = 32800 - expected_grid_map[47][48] = 128 - expected_grid_map[48][33] = 32800 - expected_grid_map[48][39] = 32800 - expected_grid_map[48][47] = 32800 - expected_grid_map[49][33] = 72 - expected_grid_map[49][34] = 1025 - expected_grid_map[49][35] = 1025 - expected_grid_map[49][36] = 1025 - expected_grid_map[49][37] = 1025 - expected_grid_map[49][38] = 1025 - expected_grid_map[49][39] = 2136 - expected_grid_map[49][40] = 1025 - expected_grid_map[49][41] = 1025 - expected_grid_map[49][42] = 1025 - expected_grid_map[49][43] = 1025 - expected_grid_map[49][44] = 1025 - expected_grid_map[49][45] = 1025 - expected_grid_map[49][46] = 1025 - expected_grid_map[49][47] = 2064 - + expected_grid_map[42][30] = 1025 + expected_grid_map[42][31] = 1025 + expected_grid_map[42][32] = 1025 + expected_grid_map[42][33] = 1025 + expected_grid_map[42][34] = 1025 + expected_grid_map[42][35] = 1025 + expected_grid_map[42][36] = 1097 + expected_grid_map[42][37] = 1025 + expected_grid_map[42][38] = 1025 + expected_grid_map[42][39] = 1025 + expected_grid_map[42][40] = 3089 + expected_grid_map[42][41] = 1025 + expected_grid_map[42][42] = 1025 + expected_grid_map[42][43] = 1025 + expected_grid_map[42][44] = 1025 + expected_grid_map[42][45] = 1025 + expected_grid_map[42][46] = 1025 + expected_grid_map[42][47] = 1025 + expected_grid_map[42][48] = 1097 + expected_grid_map[42][49] = 2064 assert np.array_equal(env.rail.grid, expected_grid_map), "actual={}, expected={}".format(env.rail.grid, expected_grid_map) s0 = 0 s1 = 0 for a in range(env.get_num_agents()): - s0 = Vec2d.get_manhattan_distance(env.agents[a].position, (0, 0)) - s1 = Vec2d.get_chebyshev_distance(env.agents[a].position, (0, 0)) + s0 = Vec2d.get_manhattan_distance(env.agents[a].initial_position, (0, 0)) + s1 = Vec2d.get_chebyshev_distance(env.agents[a].initial_position, (0, 0)) assert s0 == 53, "actual={}".format(s0) - assert s1 == 36, "actual={}".format(s1) + assert s1 == 44, "actual={}".format(s1) def test_sparse_rail_generator_deterministic(): @@ -734,23 +617,18 @@ def test_sparse_rail_generator_deterministic(): env = RailEnv(width=25, height=30, - rail_generator=sparse_rail_generator(num_cities=5, - # Number of cities in map (where train stations are) - num_intersections=4, - # Number of intersections (no start / target) - num_trainstations=25, # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities/intersections + rail_generator=sparse_rail_generator(max_num_cities=5, + max_rails_between_cities=3, seed=215545, # Random seed - grid_mode=True, - enhance_intersection=False + grid_mode=True ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=1, stochastic_data=stochastic_data, # Malfunction data generator ) + # for r in range(env.height): + # for c in range(env.width): + # print("assert env.rail.get_full_transitions({}, {}) == {}, \"[{}][{}]\"".format(r,c,env.rail.get_full_transitions(r,c),r,c)) assert env.rail.get_full_transitions(0, 0) == 0, "[0][0]" assert env.rail.get_full_transitions(0, 1) == 0, "[0][1]" assert env.rail.get_full_transitions(0, 2) == 0, "[0][2]" @@ -761,16 +639,16 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(0, 7) == 0, "[0][7]" assert env.rail.get_full_transitions(0, 8) == 0, "[0][8]" assert env.rail.get_full_transitions(0, 9) == 0, "[0][9]" - assert env.rail.get_full_transitions(0, 10) == 0, "[0][10]" - assert env.rail.get_full_transitions(0, 11) == 0, "[0][11]" - assert env.rail.get_full_transitions(0, 12) == 0, "[0][12]" - assert env.rail.get_full_transitions(0, 13) == 0, "[0][13]" - assert env.rail.get_full_transitions(0, 14) == 0, "[0][14]" - assert env.rail.get_full_transitions(0, 15) == 0, "[0][15]" - assert env.rail.get_full_transitions(0, 16) == 0, "[0][16]" - assert env.rail.get_full_transitions(0, 17) == 0, "[0][17]" - assert env.rail.get_full_transitions(0, 18) == 0, "[0][18]" - assert env.rail.get_full_transitions(0, 19) == 0, "[0][19]" + assert env.rail.get_full_transitions(0, 10) == 16386, "[0][10]" + assert env.rail.get_full_transitions(0, 11) == 1025, "[0][11]" + assert env.rail.get_full_transitions(0, 12) == 1025, "[0][12]" + assert env.rail.get_full_transitions(0, 13) == 1025, "[0][13]" + assert env.rail.get_full_transitions(0, 14) == 17411, "[0][14]" + assert env.rail.get_full_transitions(0, 15) == 1025, "[0][15]" + assert env.rail.get_full_transitions(0, 16) == 1025, "[0][16]" + assert env.rail.get_full_transitions(0, 17) == 1025, "[0][17]" + assert env.rail.get_full_transitions(0, 18) == 5633, "[0][18]" + assert env.rail.get_full_transitions(0, 19) == 4608, "[0][19]" assert env.rail.get_full_transitions(0, 20) == 0, "[0][20]" assert env.rail.get_full_transitions(0, 21) == 0, "[0][21]" assert env.rail.get_full_transitions(0, 22) == 0, "[0][22]" @@ -786,17 +664,17 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(1, 7) == 0, "[1][7]" assert env.rail.get_full_transitions(1, 8) == 0, "[1][8]" assert env.rail.get_full_transitions(1, 9) == 0, "[1][9]" - assert env.rail.get_full_transitions(1, 10) == 0, "[1][10]" + assert env.rail.get_full_transitions(1, 10) == 32800, "[1][10]" assert env.rail.get_full_transitions(1, 11) == 0, "[1][11]" assert env.rail.get_full_transitions(1, 12) == 0, "[1][12]" assert env.rail.get_full_transitions(1, 13) == 0, "[1][13]" - assert env.rail.get_full_transitions(1, 14) == 0, "[1][14]" + assert env.rail.get_full_transitions(1, 14) == 32800, "[1][14]" assert env.rail.get_full_transitions(1, 15) == 0, "[1][15]" assert env.rail.get_full_transitions(1, 16) == 0, "[1][16]" - assert env.rail.get_full_transitions(1, 17) == 0, "[1][17]" - assert env.rail.get_full_transitions(1, 18) == 0, "[1][18]" - assert env.rail.get_full_transitions(1, 19) == 0, "[1][19]" - assert env.rail.get_full_transitions(1, 20) == 0, "[1][20]" + assert env.rail.get_full_transitions(1, 17) == 16386, "[1][17]" + assert env.rail.get_full_transitions(1, 18) == 38505, "[1][18]" + assert env.rail.get_full_transitions(1, 19) == 52275, "[1][19]" + assert env.rail.get_full_transitions(1, 20) == 4608, "[1][20]" assert env.rail.get_full_transitions(1, 21) == 0, "[1][21]" assert env.rail.get_full_transitions(1, 22) == 0, "[1][22]" assert env.rail.get_full_transitions(1, 23) == 0, "[1][23]" @@ -811,17 +689,17 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(2, 7) == 0, "[2][7]" assert env.rail.get_full_transitions(2, 8) == 0, "[2][8]" assert env.rail.get_full_transitions(2, 9) == 0, "[2][9]" - assert env.rail.get_full_transitions(2, 10) == 0, "[2][10]" + assert env.rail.get_full_transitions(2, 10) == 32800, "[2][10]" assert env.rail.get_full_transitions(2, 11) == 0, "[2][11]" assert env.rail.get_full_transitions(2, 12) == 0, "[2][12]" assert env.rail.get_full_transitions(2, 13) == 0, "[2][13]" - assert env.rail.get_full_transitions(2, 14) == 0, "[2][14]" + assert env.rail.get_full_transitions(2, 14) == 32800, "[2][14]" assert env.rail.get_full_transitions(2, 15) == 0, "[2][15]" assert env.rail.get_full_transitions(2, 16) == 0, "[2][16]" - assert env.rail.get_full_transitions(2, 17) == 0, "[2][17]" - assert env.rail.get_full_transitions(2, 18) == 0, "[2][18]" - assert env.rail.get_full_transitions(2, 19) == 0, "[2][19]" - assert env.rail.get_full_transitions(2, 20) == 0, "[2][20]" + assert env.rail.get_full_transitions(2, 17) == 32800, "[2][17]" + assert env.rail.get_full_transitions(2, 18) == 32800, "[2][18]" + assert env.rail.get_full_transitions(2, 19) == 32800, "[2][19]" + assert env.rail.get_full_transitions(2, 20) == 32800, "[2][20]" assert env.rail.get_full_transitions(2, 21) == 0, "[2][21]" assert env.rail.get_full_transitions(2, 22) == 0, "[2][22]" assert env.rail.get_full_transitions(2, 23) == 0, "[2][23]" @@ -829,107 +707,107 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(3, 0) == 0, "[3][0]" assert env.rail.get_full_transitions(3, 1) == 0, "[3][1]" assert env.rail.get_full_transitions(3, 2) == 0, "[3][2]" - assert env.rail.get_full_transitions(3, 3) == 16386, "[3][3]" - assert env.rail.get_full_transitions(3, 4) == 1025, "[3][4]" - assert env.rail.get_full_transitions(3, 5) == 1025, "[3][5]" - assert env.rail.get_full_transitions(3, 6) == 1025, "[3][6]" - assert env.rail.get_full_transitions(3, 7) == 1025, "[3][7]" - assert env.rail.get_full_transitions(3, 8) == 1025, "[3][8]" - assert env.rail.get_full_transitions(3, 9) == 1025, "[3][9]" - assert env.rail.get_full_transitions(3, 10) == 1025, "[3][10]" - assert env.rail.get_full_transitions(3, 11) == 1025, "[3][11]" - assert env.rail.get_full_transitions(3, 12) == 4608, "[3][12]" + assert env.rail.get_full_transitions(3, 3) == 0, "[3][3]" + assert env.rail.get_full_transitions(3, 4) == 0, "[3][4]" + assert env.rail.get_full_transitions(3, 5) == 0, "[3][5]" + assert env.rail.get_full_transitions(3, 6) == 0, "[3][6]" + assert env.rail.get_full_transitions(3, 7) == 0, "[3][7]" + assert env.rail.get_full_transitions(3, 8) == 0, "[3][8]" + assert env.rail.get_full_transitions(3, 9) == 0, "[3][9]" + assert env.rail.get_full_transitions(3, 10) == 32800, "[3][10]" + assert env.rail.get_full_transitions(3, 11) == 0, "[3][11]" + assert env.rail.get_full_transitions(3, 12) == 0, "[3][12]" assert env.rail.get_full_transitions(3, 13) == 0, "[3][13]" - assert env.rail.get_full_transitions(3, 14) == 0, "[3][14]" + assert env.rail.get_full_transitions(3, 14) == 32800, "[3][14]" assert env.rail.get_full_transitions(3, 15) == 0, "[3][15]" assert env.rail.get_full_transitions(3, 16) == 0, "[3][16]" - assert env.rail.get_full_transitions(3, 17) == 0, "[3][17]" - assert env.rail.get_full_transitions(3, 18) == 0, "[3][18]" - assert env.rail.get_full_transitions(3, 19) == 0, "[3][19]" - assert env.rail.get_full_transitions(3, 20) == 0, "[3][20]" + assert env.rail.get_full_transitions(3, 17) == 32800, "[3][17]" + assert env.rail.get_full_transitions(3, 18) == 32800, "[3][18]" + assert env.rail.get_full_transitions(3, 19) == 32800, "[3][19]" + assert env.rail.get_full_transitions(3, 20) == 32800, "[3][20]" assert env.rail.get_full_transitions(3, 21) == 0, "[3][21]" - assert env.rail.get_full_transitions(3, 22) == 8192, "[3][22]" + assert env.rail.get_full_transitions(3, 22) == 0, "[3][22]" assert env.rail.get_full_transitions(3, 23) == 0, "[3][23]" assert env.rail.get_full_transitions(3, 24) == 0, "[3][24]" assert env.rail.get_full_transitions(4, 0) == 0, "[4][0]" - assert env.rail.get_full_transitions(4, 1) == 0, "[4][1]" - assert env.rail.get_full_transitions(4, 2) == 0, "[4][2]" - assert env.rail.get_full_transitions(4, 3) == 32800, "[4][3]" - assert env.rail.get_full_transitions(4, 4) == 0, "[4][4]" - assert env.rail.get_full_transitions(4, 5) == 0, "[4][5]" - assert env.rail.get_full_transitions(4, 6) == 0, "[4][6]" - assert env.rail.get_full_transitions(4, 7) == 0, "[4][7]" - assert env.rail.get_full_transitions(4, 8) == 0, "[4][8]" - assert env.rail.get_full_transitions(4, 9) == 0, "[4][9]" - assert env.rail.get_full_transitions(4, 10) == 0, "[4][10]" + assert env.rail.get_full_transitions(4, 1) == 16386, "[4][1]" + assert env.rail.get_full_transitions(4, 2) == 1025, "[4][2]" + assert env.rail.get_full_transitions(4, 3) == 1025, "[4][3]" + assert env.rail.get_full_transitions(4, 4) == 1025, "[4][4]" + assert env.rail.get_full_transitions(4, 5) == 1025, "[4][5]" + assert env.rail.get_full_transitions(4, 6) == 1025, "[4][6]" + assert env.rail.get_full_transitions(4, 7) == 1025, "[4][7]" + assert env.rail.get_full_transitions(4, 8) == 1025, "[4][8]" + assert env.rail.get_full_transitions(4, 9) == 4608, "[4][9]" + assert env.rail.get_full_transitions(4, 10) == 32800, "[4][10]" assert env.rail.get_full_transitions(4, 11) == 0, "[4][11]" - assert env.rail.get_full_transitions(4, 12) == 32800, "[4][12]" + assert env.rail.get_full_transitions(4, 12) == 0, "[4][12]" assert env.rail.get_full_transitions(4, 13) == 0, "[4][13]" - assert env.rail.get_full_transitions(4, 14) == 0, "[4][14]" + assert env.rail.get_full_transitions(4, 14) == 32800, "[4][14]" assert env.rail.get_full_transitions(4, 15) == 0, "[4][15]" assert env.rail.get_full_transitions(4, 16) == 0, "[4][16]" - assert env.rail.get_full_transitions(4, 17) == 0, "[4][17]" - assert env.rail.get_full_transitions(4, 18) == 0, "[4][18]" - assert env.rail.get_full_transitions(4, 19) == 0, "[4][19]" - assert env.rail.get_full_transitions(4, 20) == 0, "[4][20]" + assert env.rail.get_full_transitions(4, 17) == 32800, "[4][17]" + assert env.rail.get_full_transitions(4, 18) == 32800, "[4][18]" + assert env.rail.get_full_transitions(4, 19) == 32800, "[4][19]" + assert env.rail.get_full_transitions(4, 20) == 32800, "[4][20]" assert env.rail.get_full_transitions(4, 21) == 0, "[4][21]" - assert env.rail.get_full_transitions(4, 22) == 32800, "[4][22]" + assert env.rail.get_full_transitions(4, 22) == 0, "[4][22]" assert env.rail.get_full_transitions(4, 23) == 0, "[4][23]" assert env.rail.get_full_transitions(4, 24) == 0, "[4][24]" - assert env.rail.get_full_transitions(5, 0) == 0, "[5][0]" - assert env.rail.get_full_transitions(5, 1) == 0, "[5][1]" - assert env.rail.get_full_transitions(5, 2) == 0, "[5][2]" - assert env.rail.get_full_transitions(5, 3) == 32800, "[5][3]" - assert env.rail.get_full_transitions(5, 4) == 0, "[5][4]" - assert env.rail.get_full_transitions(5, 5) == 0, "[5][5]" - assert env.rail.get_full_transitions(5, 6) == 0, "[5][6]" - assert env.rail.get_full_transitions(5, 7) == 0, "[5][7]" - assert env.rail.get_full_transitions(5, 8) == 0, "[5][8]" - assert env.rail.get_full_transitions(5, 9) == 0, "[5][9]" - assert env.rail.get_full_transitions(5, 10) == 0, "[5][10]" - assert env.rail.get_full_transitions(5, 11) == 0, "[5][11]" - assert env.rail.get_full_transitions(5, 12) == 32800, "[5][12]" - assert env.rail.get_full_transitions(5, 13) == 0, "[5][13]" - assert env.rail.get_full_transitions(5, 14) == 0, "[5][14]" + assert env.rail.get_full_transitions(5, 0) == 16386, "[5][0]" + assert env.rail.get_full_transitions(5, 1) == 52275, "[5][1]" + assert env.rail.get_full_transitions(5, 2) == 1025, "[5][2]" + assert env.rail.get_full_transitions(5, 3) == 1025, "[5][3]" + assert env.rail.get_full_transitions(5, 4) == 1025, "[5][4]" + assert env.rail.get_full_transitions(5, 5) == 1025, "[5][5]" + assert env.rail.get_full_transitions(5, 6) == 1025, "[5][6]" + assert env.rail.get_full_transitions(5, 7) == 1025, "[5][7]" + assert env.rail.get_full_transitions(5, 8) == 1025, "[5][8]" + assert env.rail.get_full_transitions(5, 9) == 52275, "[5][9]" + assert env.rail.get_full_transitions(5, 10) == 3089, "[5][10]" + assert env.rail.get_full_transitions(5, 11) == 1025, "[5][11]" + assert env.rail.get_full_transitions(5, 12) == 1025, "[5][12]" + assert env.rail.get_full_transitions(5, 13) == 1025, "[5][13]" + assert env.rail.get_full_transitions(5, 14) == 2064, "[5][14]" assert env.rail.get_full_transitions(5, 15) == 0, "[5][15]" assert env.rail.get_full_transitions(5, 16) == 0, "[5][16]" - assert env.rail.get_full_transitions(5, 17) == 0, "[5][17]" - assert env.rail.get_full_transitions(5, 18) == 0, "[5][18]" - assert env.rail.get_full_transitions(5, 19) == 0, "[5][19]" - assert env.rail.get_full_transitions(5, 20) == 0, "[5][20]" + assert env.rail.get_full_transitions(5, 17) == 32800, "[5][17]" + assert env.rail.get_full_transitions(5, 18) == 32800, "[5][18]" + assert env.rail.get_full_transitions(5, 19) == 32800, "[5][19]" + assert env.rail.get_full_transitions(5, 20) == 32800, "[5][20]" assert env.rail.get_full_transitions(5, 21) == 0, "[5][21]" - assert env.rail.get_full_transitions(5, 22) == 32800, "[5][22]" + assert env.rail.get_full_transitions(5, 22) == 0, "[5][22]" assert env.rail.get_full_transitions(5, 23) == 0, "[5][23]" assert env.rail.get_full_transitions(5, 24) == 0, "[5][24]" - assert env.rail.get_full_transitions(6, 0) == 0, "[6][0]" - assert env.rail.get_full_transitions(6, 1) == 0, "[6][1]" - assert env.rail.get_full_transitions(6, 2) == 0, "[6][2]" - assert env.rail.get_full_transitions(6, 3) == 32800, "[6][3]" - assert env.rail.get_full_transitions(6, 4) == 0, "[6][4]" - assert env.rail.get_full_transitions(6, 5) == 0, "[6][5]" - assert env.rail.get_full_transitions(6, 6) == 0, "[6][6]" - assert env.rail.get_full_transitions(6, 7) == 0, "[6][7]" - assert env.rail.get_full_transitions(6, 8) == 0, "[6][8]" - assert env.rail.get_full_transitions(6, 9) == 0, "[6][9]" + assert env.rail.get_full_transitions(6, 0) == 49186, "[6][0]" + assert env.rail.get_full_transitions(6, 1) == 3089, "[6][1]" + assert env.rail.get_full_transitions(6, 2) == 1025, "[6][2]" + assert env.rail.get_full_transitions(6, 3) == 1025, "[6][3]" + assert env.rail.get_full_transitions(6, 4) == 1025, "[6][4]" + assert env.rail.get_full_transitions(6, 5) == 1025, "[6][5]" + assert env.rail.get_full_transitions(6, 6) == 1025, "[6][6]" + assert env.rail.get_full_transitions(6, 7) == 1025, "[6][7]" + assert env.rail.get_full_transitions(6, 8) == 1025, "[6][8]" + assert env.rail.get_full_transitions(6, 9) == 2064, "[6][9]" assert env.rail.get_full_transitions(6, 10) == 0, "[6][10]" assert env.rail.get_full_transitions(6, 11) == 0, "[6][11]" - assert env.rail.get_full_transitions(6, 12) == 32800, "[6][12]" + assert env.rail.get_full_transitions(6, 12) == 0, "[6][12]" assert env.rail.get_full_transitions(6, 13) == 0, "[6][13]" assert env.rail.get_full_transitions(6, 14) == 0, "[6][14]" assert env.rail.get_full_transitions(6, 15) == 0, "[6][15]" assert env.rail.get_full_transitions(6, 16) == 0, "[6][16]" - assert env.rail.get_full_transitions(6, 17) == 0, "[6][17]" - assert env.rail.get_full_transitions(6, 18) == 0, "[6][18]" - assert env.rail.get_full_transitions(6, 19) == 0, "[6][19]" - assert env.rail.get_full_transitions(6, 20) == 0, "[6][20]" + assert env.rail.get_full_transitions(6, 17) == 32800, "[6][17]" + assert env.rail.get_full_transitions(6, 18) == 32800, "[6][18]" + assert env.rail.get_full_transitions(6, 19) == 32800, "[6][19]" + assert env.rail.get_full_transitions(6, 20) == 32800, "[6][20]" assert env.rail.get_full_transitions(6, 21) == 0, "[6][21]" - assert env.rail.get_full_transitions(6, 22) == 32800, "[6][22]" + assert env.rail.get_full_transitions(6, 22) == 0, "[6][22]" assert env.rail.get_full_transitions(6, 23) == 0, "[6][23]" assert env.rail.get_full_transitions(6, 24) == 0, "[6][24]" - assert env.rail.get_full_transitions(7, 0) == 0, "[7][0]" + assert env.rail.get_full_transitions(7, 0) == 32800, "[7][0]" assert env.rail.get_full_transitions(7, 1) == 0, "[7][1]" assert env.rail.get_full_transitions(7, 2) == 0, "[7][2]" - assert env.rail.get_full_transitions(7, 3) == 32800, "[7][3]" + assert env.rail.get_full_transitions(7, 3) == 0, "[7][3]" assert env.rail.get_full_transitions(7, 4) == 0, "[7][4]" assert env.rail.get_full_transitions(7, 5) == 0, "[7][5]" assert env.rail.get_full_transitions(7, 6) == 0, "[7][6]" @@ -938,84 +816,84 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(7, 9) == 0, "[7][9]" assert env.rail.get_full_transitions(7, 10) == 0, "[7][10]" assert env.rail.get_full_transitions(7, 11) == 0, "[7][11]" - assert env.rail.get_full_transitions(7, 12) == 32800, "[7][12]" + assert env.rail.get_full_transitions(7, 12) == 0, "[7][12]" assert env.rail.get_full_transitions(7, 13) == 0, "[7][13]" assert env.rail.get_full_transitions(7, 14) == 0, "[7][14]" assert env.rail.get_full_transitions(7, 15) == 0, "[7][15]" assert env.rail.get_full_transitions(7, 16) == 0, "[7][16]" - assert env.rail.get_full_transitions(7, 17) == 0, "[7][17]" - assert env.rail.get_full_transitions(7, 18) == 0, "[7][18]" - assert env.rail.get_full_transitions(7, 19) == 0, "[7][19]" - assert env.rail.get_full_transitions(7, 20) == 0, "[7][20]" + assert env.rail.get_full_transitions(7, 17) == 32800, "[7][17]" + assert env.rail.get_full_transitions(7, 18) == 32800, "[7][18]" + assert env.rail.get_full_transitions(7, 19) == 32800, "[7][19]" + assert env.rail.get_full_transitions(7, 20) == 32800, "[7][20]" assert env.rail.get_full_transitions(7, 21) == 0, "[7][21]" - assert env.rail.get_full_transitions(7, 22) == 32800, "[7][22]" + assert env.rail.get_full_transitions(7, 22) == 0, "[7][22]" assert env.rail.get_full_transitions(7, 23) == 0, "[7][23]" assert env.rail.get_full_transitions(7, 24) == 0, "[7][24]" - assert env.rail.get_full_transitions(8, 0) == 0, "[8][0]" + assert env.rail.get_full_transitions(8, 0) == 32800, "[8][0]" assert env.rail.get_full_transitions(8, 1) == 0, "[8][1]" assert env.rail.get_full_transitions(8, 2) == 0, "[8][2]" - assert env.rail.get_full_transitions(8, 3) == 32800, "[8][3]" + assert env.rail.get_full_transitions(8, 3) == 0, "[8][3]" assert env.rail.get_full_transitions(8, 4) == 0, "[8][4]" - assert env.rail.get_full_transitions(8, 5) == 8192, "[8][5]" + assert env.rail.get_full_transitions(8, 5) == 0, "[8][5]" assert env.rail.get_full_transitions(8, 6) == 0, "[8][6]" assert env.rail.get_full_transitions(8, 7) == 0, "[8][7]" assert env.rail.get_full_transitions(8, 8) == 0, "[8][8]" - assert env.rail.get_full_transitions(8, 9) == 8192, "[8][9]" - assert env.rail.get_full_transitions(8, 10) == 8192, "[8][10]" + assert env.rail.get_full_transitions(8, 9) == 0, "[8][9]" + assert env.rail.get_full_transitions(8, 10) == 0, "[8][10]" assert env.rail.get_full_transitions(8, 11) == 0, "[8][11]" - assert env.rail.get_full_transitions(8, 12) == 32800, "[8][12]" - assert env.rail.get_full_transitions(8, 13) == 8192, "[8][13]" + assert env.rail.get_full_transitions(8, 12) == 0, "[8][12]" + assert env.rail.get_full_transitions(8, 13) == 0, "[8][13]" assert env.rail.get_full_transitions(8, 14) == 0, "[8][14]" assert env.rail.get_full_transitions(8, 15) == 0, "[8][15]" assert env.rail.get_full_transitions(8, 16) == 0, "[8][16]" - assert env.rail.get_full_transitions(8, 17) == 0, "[8][17]" - assert env.rail.get_full_transitions(8, 18) == 0, "[8][18]" - assert env.rail.get_full_transitions(8, 19) == 0, "[8][19]" - assert env.rail.get_full_transitions(8, 20) == 0, "[8][20]" + assert env.rail.get_full_transitions(8, 17) == 32800, "[8][17]" + assert env.rail.get_full_transitions(8, 18) == 32800, "[8][18]" + assert env.rail.get_full_transitions(8, 19) == 32800, "[8][19]" + assert env.rail.get_full_transitions(8, 20) == 32800, "[8][20]" assert env.rail.get_full_transitions(8, 21) == 0, "[8][21]" - assert env.rail.get_full_transitions(8, 22) == 32800, "[8][22]" + assert env.rail.get_full_transitions(8, 22) == 0, "[8][22]" assert env.rail.get_full_transitions(8, 23) == 0, "[8][23]" assert env.rail.get_full_transitions(8, 24) == 0, "[8][24]" - assert env.rail.get_full_transitions(9, 0) == 8192, "[9][0]" + assert env.rail.get_full_transitions(9, 0) == 32800, "[9][0]" assert env.rail.get_full_transitions(9, 1) == 0, "[9][1]" assert env.rail.get_full_transitions(9, 2) == 0, "[9][2]" - assert env.rail.get_full_transitions(9, 3) == 32800, "[9][3]" - assert env.rail.get_full_transitions(9, 4) == 8192, "[9][4]" - assert env.rail.get_full_transitions(9, 5) == 32800, "[9][5]" + assert env.rail.get_full_transitions(9, 3) == 0, "[9][3]" + assert env.rail.get_full_transitions(9, 4) == 0, "[9][4]" + assert env.rail.get_full_transitions(9, 5) == 0, "[9][5]" assert env.rail.get_full_transitions(9, 6) == 0, "[9][6]" assert env.rail.get_full_transitions(9, 7) == 0, "[9][7]" assert env.rail.get_full_transitions(9, 8) == 0, "[9][8]" - assert env.rail.get_full_transitions(9, 9) == 72, "[9][9]" - assert env.rail.get_full_transitions(9, 10) == 37408, "[9][10]" + assert env.rail.get_full_transitions(9, 9) == 0, "[9][9]" + assert env.rail.get_full_transitions(9, 10) == 0, "[9][10]" assert env.rail.get_full_transitions(9, 11) == 0, "[9][11]" - assert env.rail.get_full_transitions(9, 12) == 49186, "[9][12]" - assert env.rail.get_full_transitions(9, 13) == 3089, "[9][13]" - assert env.rail.get_full_transitions(9, 14) == 4608, "[9][14]" + assert env.rail.get_full_transitions(9, 12) == 0, "[9][12]" + assert env.rail.get_full_transitions(9, 13) == 0, "[9][13]" + assert env.rail.get_full_transitions(9, 14) == 0, "[9][14]" assert env.rail.get_full_transitions(9, 15) == 0, "[9][15]" assert env.rail.get_full_transitions(9, 16) == 0, "[9][16]" - assert env.rail.get_full_transitions(9, 17) == 0, "[9][17]" - assert env.rail.get_full_transitions(9, 18) == 0, "[9][18]" - assert env.rail.get_full_transitions(9, 19) == 0, "[9][19]" - assert env.rail.get_full_transitions(9, 20) == 0, "[9][20]" + assert env.rail.get_full_transitions(9, 17) == 72, "[9][17]" + assert env.rail.get_full_transitions(9, 18) == 3089, "[9][18]" + assert env.rail.get_full_transitions(9, 19) == 1097, "[9][19]" + assert env.rail.get_full_transitions(9, 20) == 2064, "[9][20]" assert env.rail.get_full_transitions(9, 21) == 0, "[9][21]" - assert env.rail.get_full_transitions(9, 22) == 32800, "[9][22]" + assert env.rail.get_full_transitions(9, 22) == 0, "[9][22]" assert env.rail.get_full_transitions(9, 23) == 0, "[9][23]" assert env.rail.get_full_transitions(9, 24) == 0, "[9][24]" assert env.rail.get_full_transitions(10, 0) == 32800, "[10][0]" assert env.rail.get_full_transitions(10, 1) == 0, "[10][1]" assert env.rail.get_full_transitions(10, 2) == 0, "[10][2]" - assert env.rail.get_full_transitions(10, 3) == 32800, "[10][3]" - assert env.rail.get_full_transitions(10, 4) == 32800, "[10][4]" - assert env.rail.get_full_transitions(10, 5) == 32800, "[10][5]" + assert env.rail.get_full_transitions(10, 3) == 0, "[10][3]" + assert env.rail.get_full_transitions(10, 4) == 0, "[10][4]" + assert env.rail.get_full_transitions(10, 5) == 0, "[10][5]" assert env.rail.get_full_transitions(10, 6) == 0, "[10][6]" assert env.rail.get_full_transitions(10, 7) == 0, "[10][7]" assert env.rail.get_full_transitions(10, 8) == 0, "[10][8]" - assert env.rail.get_full_transitions(10, 9) == 4, "[10][9]" - assert env.rail.get_full_transitions(10, 10) == 1097, "[10][10]" - assert env.rail.get_full_transitions(10, 11) == 1025, "[10][11]" - assert env.rail.get_full_transitions(10, 12) == 37408, "[10][12]" + assert env.rail.get_full_transitions(10, 9) == 0, "[10][9]" + assert env.rail.get_full_transitions(10, 10) == 0, "[10][10]" + assert env.rail.get_full_transitions(10, 11) == 0, "[10][11]" + assert env.rail.get_full_transitions(10, 12) == 0, "[10][12]" assert env.rail.get_full_transitions(10, 13) == 0, "[10][13]" - assert env.rail.get_full_transitions(10, 14) == 128, "[10][14]" + assert env.rail.get_full_transitions(10, 14) == 0, "[10][14]" assert env.rail.get_full_transitions(10, 15) == 0, "[10][15]" assert env.rail.get_full_transitions(10, 16) == 0, "[10][16]" assert env.rail.get_full_transitions(10, 17) == 0, "[10][17]" @@ -1023,22 +901,22 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(10, 19) == 0, "[10][19]" assert env.rail.get_full_transitions(10, 20) == 0, "[10][20]" assert env.rail.get_full_transitions(10, 21) == 0, "[10][21]" - assert env.rail.get_full_transitions(10, 22) == 32800, "[10][22]" + assert env.rail.get_full_transitions(10, 22) == 0, "[10][22]" assert env.rail.get_full_transitions(10, 23) == 0, "[10][23]" assert env.rail.get_full_transitions(10, 24) == 0, "[10][24]" - assert env.rail.get_full_transitions(11, 0) == 16458, "[11][0]" - assert env.rail.get_full_transitions(11, 1) == 17411, "[11][1]" - assert env.rail.get_full_transitions(11, 2) == 1025, "[11][2]" - assert env.rail.get_full_transitions(11, 3) == 52275, "[11][3]" - assert env.rail.get_full_transitions(11, 4) == 3089, "[11][4]" - assert env.rail.get_full_transitions(11, 5) == 2064, "[11][5]" + assert env.rail.get_full_transitions(11, 0) == 32800, "[11][0]" + assert env.rail.get_full_transitions(11, 1) == 0, "[11][1]" + assert env.rail.get_full_transitions(11, 2) == 0, "[11][2]" + assert env.rail.get_full_transitions(11, 3) == 0, "[11][3]" + assert env.rail.get_full_transitions(11, 4) == 0, "[11][4]" + assert env.rail.get_full_transitions(11, 5) == 0, "[11][5]" assert env.rail.get_full_transitions(11, 6) == 0, "[11][6]" assert env.rail.get_full_transitions(11, 7) == 0, "[11][7]" assert env.rail.get_full_transitions(11, 8) == 0, "[11][8]" assert env.rail.get_full_transitions(11, 9) == 0, "[11][9]" assert env.rail.get_full_transitions(11, 10) == 0, "[11][10]" assert env.rail.get_full_transitions(11, 11) == 0, "[11][11]" - assert env.rail.get_full_transitions(11, 12) == 32800, "[11][12]" + assert env.rail.get_full_transitions(11, 12) == 0, "[11][12]" assert env.rail.get_full_transitions(11, 13) == 0, "[11][13]" assert env.rail.get_full_transitions(11, 14) == 0, "[11][14]" assert env.rail.get_full_transitions(11, 15) == 0, "[11][15]" @@ -1048,124 +926,124 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(11, 19) == 0, "[11][19]" assert env.rail.get_full_transitions(11, 20) == 0, "[11][20]" assert env.rail.get_full_transitions(11, 21) == 0, "[11][21]" - assert env.rail.get_full_transitions(11, 22) == 32800, "[11][22]" + assert env.rail.get_full_transitions(11, 22) == 0, "[11][22]" assert env.rail.get_full_transitions(11, 23) == 0, "[11][23]" assert env.rail.get_full_transitions(11, 24) == 0, "[11][24]" - assert env.rail.get_full_transitions(12, 0) == 128, "[12][0]" - assert env.rail.get_full_transitions(12, 1) == 128, "[12][1]" - assert env.rail.get_full_transitions(12, 2) == 0, "[12][2]" - assert env.rail.get_full_transitions(12, 3) == 49186, "[12][3]" + assert env.rail.get_full_transitions(12, 0) == 32800, "[12][0]" + assert env.rail.get_full_transitions(12, 1) == 16386, "[12][1]" + assert env.rail.get_full_transitions(12, 2) == 1025, "[12][2]" + assert env.rail.get_full_transitions(12, 3) == 1025, "[12][3]" assert env.rail.get_full_transitions(12, 4) == 1025, "[12][4]" assert env.rail.get_full_transitions(12, 5) == 1025, "[12][5]" assert env.rail.get_full_transitions(12, 6) == 1025, "[12][6]" assert env.rail.get_full_transitions(12, 7) == 1025, "[12][7]" assert env.rail.get_full_transitions(12, 8) == 1025, "[12][8]" - assert env.rail.get_full_transitions(12, 9) == 1025, "[12][9]" - assert env.rail.get_full_transitions(12, 10) == 1025, "[12][10]" + assert env.rail.get_full_transitions(12, 9) == 4608, "[12][9]" + assert env.rail.get_full_transitions(12, 10) == 16386, "[12][10]" assert env.rail.get_full_transitions(12, 11) == 1025, "[12][11]" - assert env.rail.get_full_transitions(12, 12) == 34864, "[12][12]" - assert env.rail.get_full_transitions(12, 13) == 0, "[12][13]" - assert env.rail.get_full_transitions(12, 14) == 0, "[12][14]" - assert env.rail.get_full_transitions(12, 15) == 0, "[12][15]" - assert env.rail.get_full_transitions(12, 16) == 0, "[12][16]" - assert env.rail.get_full_transitions(12, 17) == 0, "[12][17]" - assert env.rail.get_full_transitions(12, 18) == 0, "[12][18]" - assert env.rail.get_full_transitions(12, 19) == 0, "[12][19]" - assert env.rail.get_full_transitions(12, 20) == 0, "[12][20]" - assert env.rail.get_full_transitions(12, 21) == 0, "[12][21]" - assert env.rail.get_full_transitions(12, 22) == 32800, "[12][22]" - assert env.rail.get_full_transitions(12, 23) == 0, "[12][23]" + assert env.rail.get_full_transitions(12, 12) == 1025, "[12][12]" + assert env.rail.get_full_transitions(12, 13) == 1025, "[12][13]" + assert env.rail.get_full_transitions(12, 14) == 1025, "[12][14]" + assert env.rail.get_full_transitions(12, 15) == 5633, "[12][15]" + assert env.rail.get_full_transitions(12, 16) == 1025, "[12][16]" + assert env.rail.get_full_transitions(12, 17) == 1025, "[12][17]" + assert env.rail.get_full_transitions(12, 18) == 1025, "[12][18]" + assert env.rail.get_full_transitions(12, 19) == 1025, "[12][19]" + assert env.rail.get_full_transitions(12, 20) == 1025, "[12][20]" + assert env.rail.get_full_transitions(12, 21) == 1025, "[12][21]" + assert env.rail.get_full_transitions(12, 22) == 1025, "[12][22]" + assert env.rail.get_full_transitions(12, 23) == 4608, "[12][23]" assert env.rail.get_full_transitions(12, 24) == 0, "[12][24]" - assert env.rail.get_full_transitions(13, 0) == 0, "[13][0]" - assert env.rail.get_full_transitions(13, 1) == 0, "[13][1]" - assert env.rail.get_full_transitions(13, 2) == 0, "[13][2]" - assert env.rail.get_full_transitions(13, 3) == 32800, "[13][3]" - assert env.rail.get_full_transitions(13, 4) == 0, "[13][4]" - assert env.rail.get_full_transitions(13, 5) == 0, "[13][5]" - assert env.rail.get_full_transitions(13, 6) == 0, "[13][6]" - assert env.rail.get_full_transitions(13, 7) == 0, "[13][7]" - assert env.rail.get_full_transitions(13, 8) == 0, "[13][8]" - assert env.rail.get_full_transitions(13, 9) == 0, "[13][9]" - assert env.rail.get_full_transitions(13, 10) == 0, "[13][10]" - assert env.rail.get_full_transitions(13, 11) == 0, "[13][11]" - assert env.rail.get_full_transitions(13, 12) == 32800, "[13][12]" - assert env.rail.get_full_transitions(13, 13) == 0, "[13][13]" - assert env.rail.get_full_transitions(13, 14) == 0, "[13][14]" - assert env.rail.get_full_transitions(13, 15) == 0, "[13][15]" - assert env.rail.get_full_transitions(13, 16) == 0, "[13][16]" - assert env.rail.get_full_transitions(13, 17) == 0, "[13][17]" - assert env.rail.get_full_transitions(13, 18) == 0, "[13][18]" - assert env.rail.get_full_transitions(13, 19) == 0, "[13][19]" - assert env.rail.get_full_transitions(13, 20) == 0, "[13][20]" - assert env.rail.get_full_transitions(13, 21) == 0, "[13][21]" - assert env.rail.get_full_transitions(13, 22) == 32800, "[13][22]" - assert env.rail.get_full_transitions(13, 23) == 0, "[13][23]" + assert env.rail.get_full_transitions(13, 0) == 16458, "[13][0]" + assert env.rail.get_full_transitions(13, 1) == 52275, "[13][1]" + assert env.rail.get_full_transitions(13, 2) == 1025, "[13][2]" + assert env.rail.get_full_transitions(13, 3) == 1025, "[13][3]" + assert env.rail.get_full_transitions(13, 4) == 1025, "[13][4]" + assert env.rail.get_full_transitions(13, 5) == 1025, "[13][5]" + assert env.rail.get_full_transitions(13, 6) == 1025, "[13][6]" + assert env.rail.get_full_transitions(13, 7) == 1025, "[13][7]" + assert env.rail.get_full_transitions(13, 8) == 1025, "[13][8]" + assert env.rail.get_full_transitions(13, 9) == 52275, "[13][9]" + assert env.rail.get_full_transitions(13, 10) == 3089, "[13][10]" + assert env.rail.get_full_transitions(13, 11) == 1025, "[13][11]" + assert env.rail.get_full_transitions(13, 12) == 1025, "[13][12]" + assert env.rail.get_full_transitions(13, 13) == 1025, "[13][13]" + assert env.rail.get_full_transitions(13, 14) == 1025, "[13][14]" + assert env.rail.get_full_transitions(13, 15) == 38505, "[13][15]" + assert env.rail.get_full_transitions(13, 16) == 1025, "[13][16]" + assert env.rail.get_full_transitions(13, 17) == 1025, "[13][17]" + assert env.rail.get_full_transitions(13, 18) == 1025, "[13][18]" + assert env.rail.get_full_transitions(13, 19) == 1025, "[13][19]" + assert env.rail.get_full_transitions(13, 20) == 1025, "[13][20]" + assert env.rail.get_full_transitions(13, 21) == 1025, "[13][21]" + assert env.rail.get_full_transitions(13, 22) == 1025, "[13][22]" + assert env.rail.get_full_transitions(13, 23) == 37408, "[13][23]" assert env.rail.get_full_transitions(13, 24) == 0, "[13][24]" - assert env.rail.get_full_transitions(14, 0) == 0, "[14][0]" - assert env.rail.get_full_transitions(14, 1) == 0, "[14][1]" - assert env.rail.get_full_transitions(14, 2) == 0, "[14][2]" - assert env.rail.get_full_transitions(14, 3) == 32800, "[14][3]" - assert env.rail.get_full_transitions(14, 4) == 0, "[14][4]" - assert env.rail.get_full_transitions(14, 5) == 0, "[14][5]" - assert env.rail.get_full_transitions(14, 6) == 0, "[14][6]" - assert env.rail.get_full_transitions(14, 7) == 0, "[14][7]" - assert env.rail.get_full_transitions(14, 8) == 0, "[14][8]" - assert env.rail.get_full_transitions(14, 9) == 0, "[14][9]" - assert env.rail.get_full_transitions(14, 10) == 0, "[14][10]" - assert env.rail.get_full_transitions(14, 11) == 0, "[14][11]" - assert env.rail.get_full_transitions(14, 12) == 32800, "[14][12]" - assert env.rail.get_full_transitions(14, 13) == 0, "[14][13]" - assert env.rail.get_full_transitions(14, 14) == 0, "[14][14]" - assert env.rail.get_full_transitions(14, 15) == 0, "[14][15]" - assert env.rail.get_full_transitions(14, 16) == 0, "[14][16]" - assert env.rail.get_full_transitions(14, 17) == 0, "[14][17]" - assert env.rail.get_full_transitions(14, 18) == 0, "[14][18]" - assert env.rail.get_full_transitions(14, 19) == 0, "[14][19]" - assert env.rail.get_full_transitions(14, 20) == 0, "[14][20]" - assert env.rail.get_full_transitions(14, 21) == 0, "[14][21]" - assert env.rail.get_full_transitions(14, 22) == 32800, "[14][22]" - assert env.rail.get_full_transitions(14, 23) == 0, "[14][23]" + assert env.rail.get_full_transitions(14, 0) == 49186, "[14][0]" + assert env.rail.get_full_transitions(14, 1) == 38505, "[14][1]" + assert env.rail.get_full_transitions(14, 2) == 1025, "[14][2]" + assert env.rail.get_full_transitions(14, 3) == 1025, "[14][3]" + assert env.rail.get_full_transitions(14, 4) == 1025, "[14][4]" + assert env.rail.get_full_transitions(14, 5) == 1025, "[14][5]" + assert env.rail.get_full_transitions(14, 6) == 1025, "[14][6]" + assert env.rail.get_full_transitions(14, 7) == 1025, "[14][7]" + assert env.rail.get_full_transitions(14, 8) == 1025, "[14][8]" + assert env.rail.get_full_transitions(14, 9) == 38505, "[14][9]" + assert env.rail.get_full_transitions(14, 10) == 5633, "[14][10]" + assert env.rail.get_full_transitions(14, 11) == 1025, "[14][11]" + assert env.rail.get_full_transitions(14, 12) == 1025, "[14][12]" + assert env.rail.get_full_transitions(14, 13) == 1025, "[14][13]" + assert env.rail.get_full_transitions(14, 14) == 1025, "[14][14]" + assert env.rail.get_full_transitions(14, 15) == 38505, "[14][15]" + assert env.rail.get_full_transitions(14, 16) == 1025, "[14][16]" + assert env.rail.get_full_transitions(14, 17) == 1025, "[14][17]" + assert env.rail.get_full_transitions(14, 18) == 1025, "[14][18]" + assert env.rail.get_full_transitions(14, 19) == 1025, "[14][19]" + assert env.rail.get_full_transitions(14, 20) == 1025, "[14][20]" + assert env.rail.get_full_transitions(14, 21) == 1025, "[14][21]" + assert env.rail.get_full_transitions(14, 22) == 1025, "[14][22]" + assert env.rail.get_full_transitions(14, 23) == 34864, "[14][23]" assert env.rail.get_full_transitions(14, 24) == 0, "[14][24]" - assert env.rail.get_full_transitions(15, 0) == 0, "[15][0]" - assert env.rail.get_full_transitions(15, 1) == 0, "[15][1]" - assert env.rail.get_full_transitions(15, 2) == 0, "[15][2]" - assert env.rail.get_full_transitions(15, 3) == 32800, "[15][3]" - assert env.rail.get_full_transitions(15, 4) == 0, "[15][4]" - assert env.rail.get_full_transitions(15, 5) == 0, "[15][5]" - assert env.rail.get_full_transitions(15, 6) == 0, "[15][6]" - assert env.rail.get_full_transitions(15, 7) == 0, "[15][7]" - assert env.rail.get_full_transitions(15, 8) == 0, "[15][8]" - assert env.rail.get_full_transitions(15, 9) == 0, "[15][9]" - assert env.rail.get_full_transitions(15, 10) == 0, "[15][10]" - assert env.rail.get_full_transitions(15, 11) == 0, "[15][11]" - assert env.rail.get_full_transitions(15, 12) == 32800, "[15][12]" - assert env.rail.get_full_transitions(15, 13) == 0, "[15][13]" - assert env.rail.get_full_transitions(15, 14) == 0, "[15][14]" - assert env.rail.get_full_transitions(15, 15) == 0, "[15][15]" - assert env.rail.get_full_transitions(15, 16) == 0, "[15][16]" - assert env.rail.get_full_transitions(15, 17) == 0, "[15][17]" - assert env.rail.get_full_transitions(15, 18) == 0, "[15][18]" - assert env.rail.get_full_transitions(15, 19) == 0, "[15][19]" - assert env.rail.get_full_transitions(15, 20) == 0, "[15][20]" - assert env.rail.get_full_transitions(15, 21) == 0, "[15][21]" - assert env.rail.get_full_transitions(15, 22) == 32800, "[15][22]" - assert env.rail.get_full_transitions(15, 23) == 0, "[15][23]" + assert env.rail.get_full_transitions(15, 0) == 32800, "[15][0]" + assert env.rail.get_full_transitions(15, 1) == 72, "[15][1]" + assert env.rail.get_full_transitions(15, 2) == 1025, "[15][2]" + assert env.rail.get_full_transitions(15, 3) == 1025, "[15][3]" + assert env.rail.get_full_transitions(15, 4) == 1025, "[15][4]" + assert env.rail.get_full_transitions(15, 5) == 1025, "[15][5]" + assert env.rail.get_full_transitions(15, 6) == 1025, "[15][6]" + assert env.rail.get_full_transitions(15, 7) == 1025, "[15][7]" + assert env.rail.get_full_transitions(15, 8) == 1025, "[15][8]" + assert env.rail.get_full_transitions(15, 9) == 2064, "[15][9]" + assert env.rail.get_full_transitions(15, 10) == 32872, "[15][10]" + assert env.rail.get_full_transitions(15, 11) == 1025, "[15][11]" + assert env.rail.get_full_transitions(15, 12) == 1025, "[15][12]" + assert env.rail.get_full_transitions(15, 13) == 1025, "[15][13]" + assert env.rail.get_full_transitions(15, 14) == 17411, "[15][14]" + assert env.rail.get_full_transitions(15, 15) == 1097, "[15][15]" + assert env.rail.get_full_transitions(15, 16) == 1025, "[15][16]" + assert env.rail.get_full_transitions(15, 17) == 1025, "[15][17]" + assert env.rail.get_full_transitions(15, 18) == 1025, "[15][18]" + assert env.rail.get_full_transitions(15, 19) == 1025, "[15][19]" + assert env.rail.get_full_transitions(15, 20) == 1025, "[15][20]" + assert env.rail.get_full_transitions(15, 21) == 1025, "[15][21]" + assert env.rail.get_full_transitions(15, 22) == 1025, "[15][22]" + assert env.rail.get_full_transitions(15, 23) == 2064, "[15][23]" assert env.rail.get_full_transitions(15, 24) == 0, "[15][24]" - assert env.rail.get_full_transitions(16, 0) == 0, "[16][0]" + assert env.rail.get_full_transitions(16, 0) == 32800, "[16][0]" assert env.rail.get_full_transitions(16, 1) == 0, "[16][1]" assert env.rail.get_full_transitions(16, 2) == 0, "[16][2]" - assert env.rail.get_full_transitions(16, 3) == 32800, "[16][3]" + assert env.rail.get_full_transitions(16, 3) == 0, "[16][3]" assert env.rail.get_full_transitions(16, 4) == 0, "[16][4]" assert env.rail.get_full_transitions(16, 5) == 0, "[16][5]" assert env.rail.get_full_transitions(16, 6) == 0, "[16][6]" assert env.rail.get_full_transitions(16, 7) == 0, "[16][7]" assert env.rail.get_full_transitions(16, 8) == 0, "[16][8]" assert env.rail.get_full_transitions(16, 9) == 0, "[16][9]" - assert env.rail.get_full_transitions(16, 10) == 0, "[16][10]" + assert env.rail.get_full_transitions(16, 10) == 32800, "[16][10]" assert env.rail.get_full_transitions(16, 11) == 0, "[16][11]" - assert env.rail.get_full_transitions(16, 12) == 32800, "[16][12]" + assert env.rail.get_full_transitions(16, 12) == 0, "[16][12]" assert env.rail.get_full_transitions(16, 13) == 0, "[16][13]" - assert env.rail.get_full_transitions(16, 14) == 0, "[16][14]" + assert env.rail.get_full_transitions(16, 14) == 32800, "[16][14]" assert env.rail.get_full_transitions(16, 15) == 0, "[16][15]" assert env.rail.get_full_transitions(16, 16) == 0, "[16][16]" assert env.rail.get_full_transitions(16, 17) == 0, "[16][17]" @@ -1173,24 +1051,24 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(16, 19) == 0, "[16][19]" assert env.rail.get_full_transitions(16, 20) == 0, "[16][20]" assert env.rail.get_full_transitions(16, 21) == 0, "[16][21]" - assert env.rail.get_full_transitions(16, 22) == 32800, "[16][22]" + assert env.rail.get_full_transitions(16, 22) == 0, "[16][22]" assert env.rail.get_full_transitions(16, 23) == 0, "[16][23]" assert env.rail.get_full_transitions(16, 24) == 0, "[16][24]" - assert env.rail.get_full_transitions(17, 0) == 0, "[17][0]" + assert env.rail.get_full_transitions(17, 0) == 32800, "[17][0]" assert env.rail.get_full_transitions(17, 1) == 0, "[17][1]" assert env.rail.get_full_transitions(17, 2) == 0, "[17][2]" - assert env.rail.get_full_transitions(17, 3) == 32800, "[17][3]" + assert env.rail.get_full_transitions(17, 3) == 0, "[17][3]" assert env.rail.get_full_transitions(17, 4) == 0, "[17][4]" assert env.rail.get_full_transitions(17, 5) == 0, "[17][5]" assert env.rail.get_full_transitions(17, 6) == 0, "[17][6]" assert env.rail.get_full_transitions(17, 7) == 0, "[17][7]" assert env.rail.get_full_transitions(17, 8) == 0, "[17][8]" assert env.rail.get_full_transitions(17, 9) == 0, "[17][9]" - assert env.rail.get_full_transitions(17, 10) == 0, "[17][10]" + assert env.rail.get_full_transitions(17, 10) == 32800, "[17][10]" assert env.rail.get_full_transitions(17, 11) == 0, "[17][11]" - assert env.rail.get_full_transitions(17, 12) == 32800, "[17][12]" + assert env.rail.get_full_transitions(17, 12) == 0, "[17][12]" assert env.rail.get_full_transitions(17, 13) == 0, "[17][13]" - assert env.rail.get_full_transitions(17, 14) == 0, "[17][14]" + assert env.rail.get_full_transitions(17, 14) == 32800, "[17][14]" assert env.rail.get_full_transitions(17, 15) == 0, "[17][15]" assert env.rail.get_full_transitions(17, 16) == 0, "[17][16]" assert env.rail.get_full_transitions(17, 17) == 0, "[17][17]" @@ -1198,24 +1076,24 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(17, 19) == 0, "[17][19]" assert env.rail.get_full_transitions(17, 20) == 0, "[17][20]" assert env.rail.get_full_transitions(17, 21) == 0, "[17][21]" - assert env.rail.get_full_transitions(17, 22) == 32800, "[17][22]" + assert env.rail.get_full_transitions(17, 22) == 0, "[17][22]" assert env.rail.get_full_transitions(17, 23) == 0, "[17][23]" assert env.rail.get_full_transitions(17, 24) == 0, "[17][24]" - assert env.rail.get_full_transitions(18, 0) == 0, "[18][0]" + assert env.rail.get_full_transitions(18, 0) == 32800, "[18][0]" assert env.rail.get_full_transitions(18, 1) == 0, "[18][1]" assert env.rail.get_full_transitions(18, 2) == 0, "[18][2]" - assert env.rail.get_full_transitions(18, 3) == 32800, "[18][3]" + assert env.rail.get_full_transitions(18, 3) == 0, "[18][3]" assert env.rail.get_full_transitions(18, 4) == 0, "[18][4]" assert env.rail.get_full_transitions(18, 5) == 0, "[18][5]" assert env.rail.get_full_transitions(18, 6) == 0, "[18][6]" assert env.rail.get_full_transitions(18, 7) == 0, "[18][7]" assert env.rail.get_full_transitions(18, 8) == 0, "[18][8]" assert env.rail.get_full_transitions(18, 9) == 0, "[18][9]" - assert env.rail.get_full_transitions(18, 10) == 0, "[18][10]" + assert env.rail.get_full_transitions(18, 10) == 32800, "[18][10]" assert env.rail.get_full_transitions(18, 11) == 0, "[18][11]" - assert env.rail.get_full_transitions(18, 12) == 32800, "[18][12]" + assert env.rail.get_full_transitions(18, 12) == 0, "[18][12]" assert env.rail.get_full_transitions(18, 13) == 0, "[18][13]" - assert env.rail.get_full_transitions(18, 14) == 0, "[18][14]" + assert env.rail.get_full_transitions(18, 14) == 32800, "[18][14]" assert env.rail.get_full_transitions(18, 15) == 0, "[18][15]" assert env.rail.get_full_transitions(18, 16) == 0, "[18][16]" assert env.rail.get_full_transitions(18, 17) == 0, "[18][17]" @@ -1223,24 +1101,24 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(18, 19) == 0, "[18][19]" assert env.rail.get_full_transitions(18, 20) == 0, "[18][20]" assert env.rail.get_full_transitions(18, 21) == 0, "[18][21]" - assert env.rail.get_full_transitions(18, 22) == 32800, "[18][22]" + assert env.rail.get_full_transitions(18, 22) == 0, "[18][22]" assert env.rail.get_full_transitions(18, 23) == 0, "[18][23]" assert env.rail.get_full_transitions(18, 24) == 0, "[18][24]" - assert env.rail.get_full_transitions(19, 0) == 0, "[19][0]" - assert env.rail.get_full_transitions(19, 1) == 0, "[19][1]" - assert env.rail.get_full_transitions(19, 2) == 0, "[19][2]" - assert env.rail.get_full_transitions(19, 3) == 32872, "[19][3]" - assert env.rail.get_full_transitions(19, 4) == 1025, "[19][4]" - assert env.rail.get_full_transitions(19, 5) == 1025, "[19][5]" - assert env.rail.get_full_transitions(19, 6) == 1025, "[19][6]" + assert env.rail.get_full_transitions(19, 0) == 72, "[19][0]" + assert env.rail.get_full_transitions(19, 1) == 1025, "[19][1]" + assert env.rail.get_full_transitions(19, 2) == 1025, "[19][2]" + assert env.rail.get_full_transitions(19, 3) == 5633, "[19][3]" + assert env.rail.get_full_transitions(19, 4) == 4608, "[19][4]" + assert env.rail.get_full_transitions(19, 5) == 16386, "[19][5]" + assert env.rail.get_full_transitions(19, 6) == 17411, "[19][6]" assert env.rail.get_full_transitions(19, 7) == 1025, "[19][7]" assert env.rail.get_full_transitions(19, 8) == 1025, "[19][8]" assert env.rail.get_full_transitions(19, 9) == 1025, "[19][9]" - assert env.rail.get_full_transitions(19, 10) == 1025, "[19][10]" - assert env.rail.get_full_transitions(19, 11) == 1025, "[19][11]" - assert env.rail.get_full_transitions(19, 12) == 6672, "[19][12]" + assert env.rail.get_full_transitions(19, 10) == 2064, "[19][10]" + assert env.rail.get_full_transitions(19, 11) == 0, "[19][11]" + assert env.rail.get_full_transitions(19, 12) == 0, "[19][12]" assert env.rail.get_full_transitions(19, 13) == 0, "[19][13]" - assert env.rail.get_full_transitions(19, 14) == 0, "[19][14]" + assert env.rail.get_full_transitions(19, 14) == 32800, "[19][14]" assert env.rail.get_full_transitions(19, 15) == 0, "[19][15]" assert env.rail.get_full_transitions(19, 16) == 0, "[19][16]" assert env.rail.get_full_transitions(19, 17) == 0, "[19][17]" @@ -1248,24 +1126,24 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(19, 19) == 0, "[19][19]" assert env.rail.get_full_transitions(19, 20) == 0, "[19][20]" assert env.rail.get_full_transitions(19, 21) == 0, "[19][21]" - assert env.rail.get_full_transitions(19, 22) == 32800, "[19][22]" + assert env.rail.get_full_transitions(19, 22) == 0, "[19][22]" assert env.rail.get_full_transitions(19, 23) == 0, "[19][23]" assert env.rail.get_full_transitions(19, 24) == 0, "[19][24]" assert env.rail.get_full_transitions(20, 0) == 0, "[20][0]" assert env.rail.get_full_transitions(20, 1) == 0, "[20][1]" assert env.rail.get_full_transitions(20, 2) == 0, "[20][2]" - assert env.rail.get_full_transitions(20, 3) == 32800, "[20][3]" - assert env.rail.get_full_transitions(20, 4) == 0, "[20][4]" - assert env.rail.get_full_transitions(20, 5) == 0, "[20][5]" - assert env.rail.get_full_transitions(20, 6) == 0, "[20][6]" + assert env.rail.get_full_transitions(20, 3) == 32872, "[20][3]" + assert env.rail.get_full_transitions(20, 4) == 52275, "[20][4]" + assert env.rail.get_full_transitions(20, 5) == 52275, "[20][5]" + assert env.rail.get_full_transitions(20, 6) == 37408, "[20][6]" assert env.rail.get_full_transitions(20, 7) == 0, "[20][7]" assert env.rail.get_full_transitions(20, 8) == 0, "[20][8]" assert env.rail.get_full_transitions(20, 9) == 0, "[20][9]" assert env.rail.get_full_transitions(20, 10) == 0, "[20][10]" assert env.rail.get_full_transitions(20, 11) == 0, "[20][11]" - assert env.rail.get_full_transitions(20, 12) == 32800, "[20][12]" + assert env.rail.get_full_transitions(20, 12) == 0, "[20][12]" assert env.rail.get_full_transitions(20, 13) == 0, "[20][13]" - assert env.rail.get_full_transitions(20, 14) == 0, "[20][14]" + assert env.rail.get_full_transitions(20, 14) == 32800, "[20][14]" assert env.rail.get_full_transitions(20, 15) == 0, "[20][15]" assert env.rail.get_full_transitions(20, 16) == 0, "[20][16]" assert env.rail.get_full_transitions(20, 17) == 0, "[20][17]" @@ -1273,24 +1151,24 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(20, 19) == 0, "[20][19]" assert env.rail.get_full_transitions(20, 20) == 0, "[20][20]" assert env.rail.get_full_transitions(20, 21) == 0, "[20][21]" - assert env.rail.get_full_transitions(20, 22) == 32800, "[20][22]" + assert env.rail.get_full_transitions(20, 22) == 0, "[20][22]" assert env.rail.get_full_transitions(20, 23) == 0, "[20][23]" assert env.rail.get_full_transitions(20, 24) == 0, "[20][24]" assert env.rail.get_full_transitions(21, 0) == 0, "[21][0]" assert env.rail.get_full_transitions(21, 1) == 0, "[21][1]" assert env.rail.get_full_transitions(21, 2) == 0, "[21][2]" assert env.rail.get_full_transitions(21, 3) == 32800, "[21][3]" - assert env.rail.get_full_transitions(21, 4) == 0, "[21][4]" - assert env.rail.get_full_transitions(21, 5) == 0, "[21][5]" - assert env.rail.get_full_transitions(21, 6) == 0, "[21][6]" + assert env.rail.get_full_transitions(21, 4) == 32800, "[21][4]" + assert env.rail.get_full_transitions(21, 5) == 32800, "[21][5]" + assert env.rail.get_full_transitions(21, 6) == 32800, "[21][6]" assert env.rail.get_full_transitions(21, 7) == 0, "[21][7]" assert env.rail.get_full_transitions(21, 8) == 0, "[21][8]" assert env.rail.get_full_transitions(21, 9) == 0, "[21][9]" assert env.rail.get_full_transitions(21, 10) == 0, "[21][10]" assert env.rail.get_full_transitions(21, 11) == 0, "[21][11]" - assert env.rail.get_full_transitions(21, 12) == 32800, "[21][12]" + assert env.rail.get_full_transitions(21, 12) == 0, "[21][12]" assert env.rail.get_full_transitions(21, 13) == 0, "[21][13]" - assert env.rail.get_full_transitions(21, 14) == 0, "[21][14]" + assert env.rail.get_full_transitions(21, 14) == 32800, "[21][14]" assert env.rail.get_full_transitions(21, 15) == 0, "[21][15]" assert env.rail.get_full_transitions(21, 16) == 0, "[21][16]" assert env.rail.get_full_transitions(21, 17) == 0, "[21][17]" @@ -1298,24 +1176,24 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(21, 19) == 0, "[21][19]" assert env.rail.get_full_transitions(21, 20) == 0, "[21][20]" assert env.rail.get_full_transitions(21, 21) == 0, "[21][21]" - assert env.rail.get_full_transitions(21, 22) == 32800, "[21][22]" + assert env.rail.get_full_transitions(21, 22) == 0, "[21][22]" assert env.rail.get_full_transitions(21, 23) == 0, "[21][23]" assert env.rail.get_full_transitions(21, 24) == 0, "[21][24]" assert env.rail.get_full_transitions(22, 0) == 0, "[22][0]" assert env.rail.get_full_transitions(22, 1) == 0, "[22][1]" assert env.rail.get_full_transitions(22, 2) == 0, "[22][2]" assert env.rail.get_full_transitions(22, 3) == 32800, "[22][3]" - assert env.rail.get_full_transitions(22, 4) == 0, "[22][4]" - assert env.rail.get_full_transitions(22, 5) == 0, "[22][5]" - assert env.rail.get_full_transitions(22, 6) == 0, "[22][6]" + assert env.rail.get_full_transitions(22, 4) == 32800, "[22][4]" + assert env.rail.get_full_transitions(22, 5) == 32800, "[22][5]" + assert env.rail.get_full_transitions(22, 6) == 32800, "[22][6]" assert env.rail.get_full_transitions(22, 7) == 0, "[22][7]" assert env.rail.get_full_transitions(22, 8) == 0, "[22][8]" assert env.rail.get_full_transitions(22, 9) == 0, "[22][9]" assert env.rail.get_full_transitions(22, 10) == 0, "[22][10]" assert env.rail.get_full_transitions(22, 11) == 0, "[22][11]" - assert env.rail.get_full_transitions(22, 12) == 32800, "[22][12]" + assert env.rail.get_full_transitions(22, 12) == 0, "[22][12]" assert env.rail.get_full_transitions(22, 13) == 0, "[22][13]" - assert env.rail.get_full_transitions(22, 14) == 0, "[22][14]" + assert env.rail.get_full_transitions(22, 14) == 32800, "[22][14]" assert env.rail.get_full_transitions(22, 15) == 0, "[22][15]" assert env.rail.get_full_transitions(22, 16) == 0, "[22][16]" assert env.rail.get_full_transitions(22, 17) == 0, "[22][17]" @@ -1323,24 +1201,24 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(22, 19) == 0, "[22][19]" assert env.rail.get_full_transitions(22, 20) == 0, "[22][20]" assert env.rail.get_full_transitions(22, 21) == 0, "[22][21]" - assert env.rail.get_full_transitions(22, 22) == 32800, "[22][22]" + assert env.rail.get_full_transitions(22, 22) == 0, "[22][22]" assert env.rail.get_full_transitions(22, 23) == 0, "[22][23]" assert env.rail.get_full_transitions(22, 24) == 0, "[22][24]" assert env.rail.get_full_transitions(23, 0) == 0, "[23][0]" assert env.rail.get_full_transitions(23, 1) == 0, "[23][1]" assert env.rail.get_full_transitions(23, 2) == 0, "[23][2]" assert env.rail.get_full_transitions(23, 3) == 32800, "[23][3]" - assert env.rail.get_full_transitions(23, 4) == 0, "[23][4]" - assert env.rail.get_full_transitions(23, 5) == 0, "[23][5]" - assert env.rail.get_full_transitions(23, 6) == 0, "[23][6]" + assert env.rail.get_full_transitions(23, 4) == 32800, "[23][4]" + assert env.rail.get_full_transitions(23, 5) == 32800, "[23][5]" + assert env.rail.get_full_transitions(23, 6) == 32800, "[23][6]" assert env.rail.get_full_transitions(23, 7) == 0, "[23][7]" assert env.rail.get_full_transitions(23, 8) == 0, "[23][8]" assert env.rail.get_full_transitions(23, 9) == 0, "[23][9]" assert env.rail.get_full_transitions(23, 10) == 0, "[23][10]" assert env.rail.get_full_transitions(23, 11) == 0, "[23][11]" - assert env.rail.get_full_transitions(23, 12) == 32800, "[23][12]" + assert env.rail.get_full_transitions(23, 12) == 0, "[23][12]" assert env.rail.get_full_transitions(23, 13) == 0, "[23][13]" - assert env.rail.get_full_transitions(23, 14) == 0, "[23][14]" + assert env.rail.get_full_transitions(23, 14) == 32800, "[23][14]" assert env.rail.get_full_transitions(23, 15) == 0, "[23][15]" assert env.rail.get_full_transitions(23, 16) == 0, "[23][16]" assert env.rail.get_full_transitions(23, 17) == 0, "[23][17]" @@ -1348,24 +1226,24 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(23, 19) == 0, "[23][19]" assert env.rail.get_full_transitions(23, 20) == 0, "[23][20]" assert env.rail.get_full_transitions(23, 21) == 0, "[23][21]" - assert env.rail.get_full_transitions(23, 22) == 32800, "[23][22]" + assert env.rail.get_full_transitions(23, 22) == 0, "[23][22]" assert env.rail.get_full_transitions(23, 23) == 0, "[23][23]" assert env.rail.get_full_transitions(23, 24) == 0, "[23][24]" assert env.rail.get_full_transitions(24, 0) == 0, "[24][0]" assert env.rail.get_full_transitions(24, 1) == 0, "[24][1]" assert env.rail.get_full_transitions(24, 2) == 0, "[24][2]" assert env.rail.get_full_transitions(24, 3) == 32800, "[24][3]" - assert env.rail.get_full_transitions(24, 4) == 0, "[24][4]" - assert env.rail.get_full_transitions(24, 5) == 0, "[24][5]" - assert env.rail.get_full_transitions(24, 6) == 0, "[24][6]" + assert env.rail.get_full_transitions(24, 4) == 32800, "[24][4]" + assert env.rail.get_full_transitions(24, 5) == 32800, "[24][5]" + assert env.rail.get_full_transitions(24, 6) == 32800, "[24][6]" assert env.rail.get_full_transitions(24, 7) == 0, "[24][7]" assert env.rail.get_full_transitions(24, 8) == 0, "[24][8]" - assert env.rail.get_full_transitions(24, 9) == 8192, "[24][9]" + assert env.rail.get_full_transitions(24, 9) == 0, "[24][9]" assert env.rail.get_full_transitions(24, 10) == 0, "[24][10]" assert env.rail.get_full_transitions(24, 11) == 0, "[24][11]" - assert env.rail.get_full_transitions(24, 12) == 32800, "[24][12]" + assert env.rail.get_full_transitions(24, 12) == 0, "[24][12]" assert env.rail.get_full_transitions(24, 13) == 0, "[24][13]" - assert env.rail.get_full_transitions(24, 14) == 0, "[24][14]" + assert env.rail.get_full_transitions(24, 14) == 32800, "[24][14]" assert env.rail.get_full_transitions(24, 15) == 0, "[24][15]" assert env.rail.get_full_transitions(24, 16) == 0, "[24][16]" assert env.rail.get_full_transitions(24, 17) == 0, "[24][17]" @@ -1373,24 +1251,24 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(24, 19) == 0, "[24][19]" assert env.rail.get_full_transitions(24, 20) == 0, "[24][20]" assert env.rail.get_full_transitions(24, 21) == 0, "[24][21]" - assert env.rail.get_full_transitions(24, 22) == 32800, "[24][22]" + assert env.rail.get_full_transitions(24, 22) == 0, "[24][22]" assert env.rail.get_full_transitions(24, 23) == 0, "[24][23]" assert env.rail.get_full_transitions(24, 24) == 0, "[24][24]" assert env.rail.get_full_transitions(25, 0) == 0, "[25][0]" assert env.rail.get_full_transitions(25, 1) == 0, "[25][1]" assert env.rail.get_full_transitions(25, 2) == 0, "[25][2]" assert env.rail.get_full_transitions(25, 3) == 32800, "[25][3]" - assert env.rail.get_full_transitions(25, 4) == 0, "[25][4]" - assert env.rail.get_full_transitions(25, 5) == 8192, "[25][5]" - assert env.rail.get_full_transitions(25, 6) == 0, "[25][6]" + assert env.rail.get_full_transitions(25, 4) == 32800, "[25][4]" + assert env.rail.get_full_transitions(25, 5) == 32800, "[25][5]" + assert env.rail.get_full_transitions(25, 6) == 32800, "[25][6]" assert env.rail.get_full_transitions(25, 7) == 0, "[25][7]" assert env.rail.get_full_transitions(25, 8) == 0, "[25][8]" - assert env.rail.get_full_transitions(25, 9) == 32800, "[25][9]" + assert env.rail.get_full_transitions(25, 9) == 0, "[25][9]" assert env.rail.get_full_transitions(25, 10) == 0, "[25][10]" - assert env.rail.get_full_transitions(25, 11) == 8192, "[25][11]" - assert env.rail.get_full_transitions(25, 12) == 32800, "[25][12]" + assert env.rail.get_full_transitions(25, 11) == 0, "[25][11]" + assert env.rail.get_full_transitions(25, 12) == 0, "[25][12]" assert env.rail.get_full_transitions(25, 13) == 0, "[25][13]" - assert env.rail.get_full_transitions(25, 14) == 0, "[25][14]" + assert env.rail.get_full_transitions(25, 14) == 32800, "[25][14]" assert env.rail.get_full_transitions(25, 15) == 0, "[25][15]" assert env.rail.get_full_transitions(25, 16) == 0, "[25][16]" assert env.rail.get_full_transitions(25, 17) == 0, "[25][17]" @@ -1398,24 +1276,24 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(25, 19) == 0, "[25][19]" assert env.rail.get_full_transitions(25, 20) == 0, "[25][20]" assert env.rail.get_full_transitions(25, 21) == 0, "[25][21]" - assert env.rail.get_full_transitions(25, 22) == 32800, "[25][22]" + assert env.rail.get_full_transitions(25, 22) == 0, "[25][22]" assert env.rail.get_full_transitions(25, 23) == 0, "[25][23]" assert env.rail.get_full_transitions(25, 24) == 0, "[25][24]" - assert env.rail.get_full_transitions(26, 0) == 8192, "[26][0]" - assert env.rail.get_full_transitions(26, 1) == 4, "[26][1]" - assert env.rail.get_full_transitions(26, 2) == 4608, "[26][2]" + assert env.rail.get_full_transitions(26, 0) == 0, "[26][0]" + assert env.rail.get_full_transitions(26, 1) == 0, "[26][1]" + assert env.rail.get_full_transitions(26, 2) == 0, "[26][2]" assert env.rail.get_full_transitions(26, 3) == 32800, "[26][3]" - assert env.rail.get_full_transitions(26, 4) == 0, "[26][4]" + assert env.rail.get_full_transitions(26, 4) == 32800, "[26][4]" assert env.rail.get_full_transitions(26, 5) == 32800, "[26][5]" - assert env.rail.get_full_transitions(26, 6) == 0, "[26][6]" + assert env.rail.get_full_transitions(26, 6) == 32800, "[26][6]" assert env.rail.get_full_transitions(26, 7) == 0, "[26][7]" assert env.rail.get_full_transitions(26, 8) == 0, "[26][8]" - assert env.rail.get_full_transitions(26, 9) == 32800, "[26][9]" + assert env.rail.get_full_transitions(26, 9) == 0, "[26][9]" assert env.rail.get_full_transitions(26, 10) == 0, "[26][10]" - assert env.rail.get_full_transitions(26, 11) == 32800, "[26][11]" - assert env.rail.get_full_transitions(26, 12) == 32800, "[26][12]" + assert env.rail.get_full_transitions(26, 11) == 0, "[26][11]" + assert env.rail.get_full_transitions(26, 12) == 0, "[26][12]" assert env.rail.get_full_transitions(26, 13) == 0, "[26][13]" - assert env.rail.get_full_transitions(26, 14) == 0, "[26][14]" + assert env.rail.get_full_transitions(26, 14) == 32800, "[26][14]" assert env.rail.get_full_transitions(26, 15) == 0, "[26][15]" assert env.rail.get_full_transitions(26, 16) == 0, "[26][16]" assert env.rail.get_full_transitions(26, 17) == 0, "[26][17]" @@ -1423,49 +1301,49 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(26, 19) == 0, "[26][19]" assert env.rail.get_full_transitions(26, 20) == 0, "[26][20]" assert env.rail.get_full_transitions(26, 21) == 0, "[26][21]" - assert env.rail.get_full_transitions(26, 22) == 32800, "[26][22]" + assert env.rail.get_full_transitions(26, 22) == 0, "[26][22]" assert env.rail.get_full_transitions(26, 23) == 0, "[26][23]" assert env.rail.get_full_transitions(26, 24) == 0, "[26][24]" - assert env.rail.get_full_transitions(27, 0) == 72, "[27][0]" - assert env.rail.get_full_transitions(27, 1) == 17411, "[27][1]" - assert env.rail.get_full_transitions(27, 2) == 1097, "[27][2]" - assert env.rail.get_full_transitions(27, 3) == 1097, "[27][3]" - assert env.rail.get_full_transitions(27, 4) == 5633, "[27][4]" - assert env.rail.get_full_transitions(27, 5) == 3089, "[27][5]" - assert env.rail.get_full_transitions(27, 6) == 1025, "[27][6]" - assert env.rail.get_full_transitions(27, 7) == 1025, "[27][7]" - assert env.rail.get_full_transitions(27, 8) == 1025, "[27][8]" - assert env.rail.get_full_transitions(27, 9) == 1097, "[27][9]" - assert env.rail.get_full_transitions(27, 10) == 17411, "[27][10]" - assert env.rail.get_full_transitions(27, 11) == 1097, "[27][11]" - assert env.rail.get_full_transitions(27, 12) == 1097, "[27][12]" - assert env.rail.get_full_transitions(27, 13) == 5633, "[27][13]" - assert env.rail.get_full_transitions(27, 14) == 1025, "[27][14]" - assert env.rail.get_full_transitions(27, 15) == 1025, "[27][15]" - assert env.rail.get_full_transitions(27, 16) == 1025, "[27][16]" - assert env.rail.get_full_transitions(27, 17) == 1025, "[27][17]" - assert env.rail.get_full_transitions(27, 18) == 1025, "[27][18]" - assert env.rail.get_full_transitions(27, 19) == 1025, "[27][19]" - assert env.rail.get_full_transitions(27, 20) == 1025, "[27][20]" - assert env.rail.get_full_transitions(27, 21) == 1025, "[27][21]" - assert env.rail.get_full_transitions(27, 22) == 2064, "[27][22]" + assert env.rail.get_full_transitions(27, 0) == 0, "[27][0]" + assert env.rail.get_full_transitions(27, 1) == 0, "[27][1]" + assert env.rail.get_full_transitions(27, 2) == 0, "[27][2]" + assert env.rail.get_full_transitions(27, 3) == 32800, "[27][3]" + assert env.rail.get_full_transitions(27, 4) == 32800, "[27][4]" + assert env.rail.get_full_transitions(27, 5) == 32800, "[27][5]" + assert env.rail.get_full_transitions(27, 6) == 32800, "[27][6]" + assert env.rail.get_full_transitions(27, 7) == 0, "[27][7]" + assert env.rail.get_full_transitions(27, 8) == 0, "[27][8]" + assert env.rail.get_full_transitions(27, 9) == 0, "[27][9]" + assert env.rail.get_full_transitions(27, 10) == 0, "[27][10]" + assert env.rail.get_full_transitions(27, 11) == 0, "[27][11]" + assert env.rail.get_full_transitions(27, 12) == 0, "[27][12]" + assert env.rail.get_full_transitions(27, 13) == 0, "[27][13]" + assert env.rail.get_full_transitions(27, 14) == 32800, "[27][14]" + assert env.rail.get_full_transitions(27, 15) == 0, "[27][15]" + assert env.rail.get_full_transitions(27, 16) == 0, "[27][16]" + assert env.rail.get_full_transitions(27, 17) == 0, "[27][17]" + assert env.rail.get_full_transitions(27, 18) == 0, "[27][18]" + assert env.rail.get_full_transitions(27, 19) == 0, "[27][19]" + assert env.rail.get_full_transitions(27, 20) == 0, "[27][20]" + assert env.rail.get_full_transitions(27, 21) == 0, "[27][21]" + assert env.rail.get_full_transitions(27, 22) == 0, "[27][22]" assert env.rail.get_full_transitions(27, 23) == 0, "[27][23]" assert env.rail.get_full_transitions(27, 24) == 0, "[27][24]" assert env.rail.get_full_transitions(28, 0) == 0, "[28][0]" - assert env.rail.get_full_transitions(28, 1) == 32800, "[28][1]" + assert env.rail.get_full_transitions(28, 1) == 0, "[28][1]" assert env.rail.get_full_transitions(28, 2) == 0, "[28][2]" - assert env.rail.get_full_transitions(28, 3) == 0, "[28][3]" - assert env.rail.get_full_transitions(28, 4) == 72, "[28][4]" - assert env.rail.get_full_transitions(28, 5) == 256, "[28][5]" - assert env.rail.get_full_transitions(28, 6) == 0, "[28][6]" + assert env.rail.get_full_transitions(28, 3) == 49186, "[28][3]" + assert env.rail.get_full_transitions(28, 4) == 52275, "[28][4]" + assert env.rail.get_full_transitions(28, 5) == 52275, "[28][5]" + assert env.rail.get_full_transitions(28, 6) == 34864, "[28][6]" assert env.rail.get_full_transitions(28, 7) == 0, "[28][7]" assert env.rail.get_full_transitions(28, 8) == 0, "[28][8]" assert env.rail.get_full_transitions(28, 9) == 0, "[28][9]" - assert env.rail.get_full_transitions(28, 10) == 32800, "[28][10]" + assert env.rail.get_full_transitions(28, 10) == 0, "[28][10]" assert env.rail.get_full_transitions(28, 11) == 0, "[28][11]" - assert env.rail.get_full_transitions(28, 12) == 16386, "[28][12]" - assert env.rail.get_full_transitions(28, 13) == 34864, "[28][13]" - assert env.rail.get_full_transitions(28, 14) == 0, "[28][14]" + assert env.rail.get_full_transitions(28, 12) == 0, "[28][12]" + assert env.rail.get_full_transitions(28, 13) == 0, "[28][13]" + assert env.rail.get_full_transitions(28, 14) == 32800, "[28][14]" assert env.rail.get_full_transitions(28, 15) == 0, "[28][15]" assert env.rail.get_full_transitions(28, 16) == 0, "[28][16]" assert env.rail.get_full_transitions(28, 17) == 0, "[28][17]" @@ -1477,20 +1355,20 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(28, 23) == 0, "[28][23]" assert env.rail.get_full_transitions(28, 24) == 0, "[28][24]" assert env.rail.get_full_transitions(29, 0) == 0, "[29][0]" - assert env.rail.get_full_transitions(29, 1) == 128, "[29][1]" + assert env.rail.get_full_transitions(29, 1) == 0, "[29][1]" assert env.rail.get_full_transitions(29, 2) == 0, "[29][2]" - assert env.rail.get_full_transitions(29, 3) == 0, "[29][3]" - assert env.rail.get_full_transitions(29, 4) == 0, "[29][4]" - assert env.rail.get_full_transitions(29, 5) == 0, "[29][5]" - assert env.rail.get_full_transitions(29, 6) == 0, "[29][6]" - assert env.rail.get_full_transitions(29, 7) == 0, "[29][7]" - assert env.rail.get_full_transitions(29, 8) == 0, "[29][8]" - assert env.rail.get_full_transitions(29, 9) == 0, "[29][9]" - assert env.rail.get_full_transitions(29, 10) == 128, "[29][10]" - assert env.rail.get_full_transitions(29, 11) == 0, "[29][11]" - assert env.rail.get_full_transitions(29, 12) == 128, "[29][12]" - assert env.rail.get_full_transitions(29, 13) == 128, "[29][13]" - assert env.rail.get_full_transitions(29, 14) == 0, "[29][14]" + assert env.rail.get_full_transitions(29, 3) == 72, "[29][3]" + assert env.rail.get_full_transitions(29, 4) == 1097, "[29][4]" + assert env.rail.get_full_transitions(29, 5) == 1097, "[29][5]" + assert env.rail.get_full_transitions(29, 6) == 1097, "[29][6]" + assert env.rail.get_full_transitions(29, 7) == 1025, "[29][7]" + assert env.rail.get_full_transitions(29, 8) == 1025, "[29][8]" + assert env.rail.get_full_transitions(29, 9) == 1025, "[29][9]" + assert env.rail.get_full_transitions(29, 10) == 1025, "[29][10]" + assert env.rail.get_full_transitions(29, 11) == 1025, "[29][11]" + assert env.rail.get_full_transitions(29, 12) == 1025, "[29][12]" + assert env.rail.get_full_transitions(29, 13) == 1025, "[29][13]" + assert env.rail.get_full_transitions(29, 14) == 2064, "[29][14]" assert env.rail.get_full_transitions(29, 15) == 0, "[29][15]" assert env.rail.get_full_transitions(29, 16) == 0, "[29][16]" assert env.rail.get_full_transitions(29, 17) == 0, "[29][17]" @@ -1505,46 +1383,33 @@ def test_sparse_rail_generator_deterministic(): def test_rail_env_action_required_info(): np.random.seed(0) + random.seed(0) speed_ration_map = {1.: 0.25, # Fast passenger train 1. / 2.: 0.25, # Fast freight train 1. / 3.: 0.25, # Slow commuter train 1. / 4.: 0.25} # Slow freight train env_always_action = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map - num_intersections=10, - # Number of interesections in map - num_trainstations=50, - # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, - # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities - seed=5, # Random seed - grid_mode=False # Ordered distribution of nodes - ), + rail_generator=sparse_rail_generator( + max_num_cities=10, + max_rails_between_cities=3, + seed=5, # Random seed + grid_mode=False # Ordered distribution of nodes + ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) np.random.seed(0) + random.seed(0) env_only_if_action_required = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map - num_intersections=10, - # Number of interesections in map - num_trainstations=50, - # Number of possible start/targets on map - min_node_dist=6, - # Minimal distance of nodes - node_radius=3, - # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities - seed=5, # Random seed - grid_mode=False - # Ordered distribution of nodes - ), + rail_generator=sparse_rail_generator( + max_num_cities=10, + max_rails_between_cities=3, + seed=5, # Random seed + grid_mode=False + # Ordered distribution of nodes + ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) @@ -1576,7 +1441,19 @@ def test_rail_env_action_required_info(): for a in range(env_always_action.get_num_agents()): assert len(obs_always_action[a]) == len(obs_only_if_action_required[a]) for i in range(len(obs_always_action[a])): - assert np.array_equal(obs_always_action[a][i], obs_only_if_action_required[a][i]) + assert len(obs_always_action[a][i]) == len(obs_only_if_action_required[a][i]) + equal = np.array_equal(obs_always_action[a][i], obs_only_if_action_required[a][i]) + if not equal: + for r in range(50): + for c in range(50): + assert np.array_equal(obs_always_action[a][i][(r, c)], obs_only_if_action_required[a][i][ + (r, c)]), "[{}] a={},i={},{}\n{}\n\nvs.\n\n{}".format(step, a, i, (r, c), + obs_always_action[a][i][(r, c)], + obs_only_if_action_required[a][ + i][(r, c)]) + assert equal, \ + "[{}] [{}][{}] {} vs. {}".format(step, a, i, obs_always_action[a][i], + obs_only_if_action_required[a][i]) assert np.array_equal(rewards_always_action[a], rewards_only_if_action_required[a]) assert np.array_equal(done_always_action[a], done_only_if_action_required[a]) assert info_always_action['action_required'][a] == info_only_if_action_required['action_required'][a] @@ -1597,18 +1474,10 @@ def test_rail_env_malfunction_speed_info(): } env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map - num_intersections=10, - # Number of interesections in map - num_trainstations=50, - # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, - # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities - seed=5, # Random seed - grid_mode=False # Ordered distribution of nodes + rail_generator=sparse_rail_generator(max_num_cities=10, + max_rails_between_cities=3, + seed=5, + grid_mode=False ), schedule_generator=sparse_schedule_generator(), number_of_agents=10, @@ -1646,14 +1515,10 @@ def test_sparse_generator_with_too_man_cities_does_not_break_down(): RailEnv(width=50, height=50, rail_generator=sparse_rail_generator( - num_cities=100, # Number of cities in map - num_intersections=10, # Number of interesections in map - num_trainstations=50, # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, # Proximity of stations to city center - num_neighb=3, # Number of connections to other cities - seed=5, # Random seed - grid_mode=False # Ordered distribution of nodes + max_num_cities=100, + max_rails_between_cities=3, + seed=5, + grid_mode=False ), schedule_generator=sparse_schedule_generator(), number_of_agents=10, diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 8b4fc8bb64b753532b96f0158d3fa754bf855e2b..70bb1eb4912470245f05e4ef112386c44620b808 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -31,15 +31,15 @@ class SingleAgentNavigationObs(ObservationBuilder): agent = self.env.agents[handle] if agent.status == RailAgentStatus.READY_TO_DEPART: - _agent_initial_position = agent.initial_position + agent_virtual_position = agent.initial_position elif agent.status == RailAgentStatus.ACTIVE: - _agent_initial_position = agent.position + agent_virtual_position = agent.position elif agent.status == RailAgentStatus.DONE: - _agent_initial_position = agent.target + agent_virtual_position = agent.target else: return None - possible_transitions = self.env.rail.get_transitions(*_agent_initial_position, agent.direction) + possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction) num_transitions = np.count_nonzero(possible_transitions) # Start from the current orientation, and see which transitions are available; @@ -51,14 +51,14 @@ class SingleAgentNavigationObs(ObservationBuilder): min_distances = [] for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: - new_position = get_new_position(_agent_initial_position, direction) + new_position = get_new_position(agent_virtual_position, direction) min_distances.append( self.env.distance_map.get()[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) observation = [0, 0, 0] - observation[np.argmin(min_distances)[0]] = 1 + observation[np.argmin(min_distances)] = 1 return observation @@ -158,7 +158,7 @@ def test_malfunction_process_statistically(): env.step(action_dict) # check that generation of malfunctions works as expected - assert nb_malfunction == 156, "nb_malfunction={}".format(nb_malfunction) + assert nb_malfunction == 128, "nb_malfunction={}".format(nb_malfunction) def test_initial_malfunction(): @@ -176,18 +176,10 @@ def test_initial_malfunction(): random.seed(0) env = RailEnv(width=25, height=30, - rail_generator=sparse_rail_generator(num_cities=5, - # Number of cities in map (where train stations are) - num_intersections=4, - # Number of intersections (no start / target) - num_trainstations=25, # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities/intersections - seed=215545, # Random seed - grid_mode=True, - enhance_intersection=False + rail_generator=sparse_rail_generator(max_num_cities=5, + max_rails_between_cities=3, + seed=215545, + grid_mode=True ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=1, @@ -222,14 +214,14 @@ def test_initial_malfunction(): # malfunctioning ends: starting and running at speed 1.0 ), Replay( - position=(28, 4), - direction=Grid4TransitionsEnum.WEST, + position=(28, 6), + direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, malfunction=0, reward=env.step_penalty * 1.0 # running at speed 1.0 ), Replay( - position=(27, 4), + position=(27, 6), direction=Grid4TransitionsEnum.NORTH, action=RailEnvActions.MOVE_FORWARD, malfunction=0, @@ -259,18 +251,10 @@ def test_initial_malfunction_stop_moving(): env = RailEnv(width=25, height=30, - rail_generator=sparse_rail_generator(num_cities=5, - # Number of cities in map (where train stations are) - num_intersections=4, - # Number of intersections (no start / target) - num_trainstations=25, # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities/intersections - seed=215545, # Random seed - grid_mode=True, - enhance_intersection=False + rail_generator=sparse_rail_generator(max_num_cities=5, + max_rails_between_cities=3, + seed=215545, + grid_mode=True ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=1, @@ -326,8 +310,8 @@ def test_initial_malfunction_stop_moving(): status=RailAgentStatus.ACTIVE ), Replay( - position=(28, 4), - direction=Grid4TransitionsEnum.WEST, + position=(28, 6), + direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, malfunction=0, reward=env.step_penalty * 1.0, # full step penalty while stopped @@ -359,18 +343,10 @@ def test_initial_malfunction_do_nothing(): env = RailEnv(width=25, height=30, - rail_generator=sparse_rail_generator(num_cities=5, - # Number of cities in map (where train stations are) - num_intersections=4, - # Number of intersections (no start / target) - num_trainstations=25, # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities/intersections - seed=215545, # Random seed - grid_mode=True, - enhance_intersection=False + rail_generator=sparse_rail_generator(max_num_cities=5, + max_rails_between_cities=3, + seed=215545, + grid_mode=True ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=1, @@ -426,8 +402,8 @@ def test_initial_malfunction_do_nothing(): status=RailAgentStatus.ACTIVE ), Replay( - position=(28, 4), - direction=Grid4TransitionsEnum.WEST, + position=(28, 6), + direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, malfunction=0, reward=env.step_penalty * 1.0, # step penalty for speed 1.0 @@ -460,18 +436,10 @@ def test_initial_nextmalfunction_not_below_zero(): env = RailEnv(width=25, height=30, - rail_generator=sparse_rail_generator(num_cities=5, - # Number of cities in map (where train stations are) - num_intersections=4, - # Number of intersections (no start / target) - num_trainstations=25, # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities/intersections - seed=215545, # Random seed - grid_mode=True, - enhance_intersection=False + rail_generator=sparse_rail_generator(max_num_cities=5, + max_rails_between_cities=3, + seed=215545, + grid_mode=True ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=1, diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py index f29629ab7c7aabb9ca2989b02a3604239d9e6143..401992e790b96df297c10f71bd152cbdef0edb9c 100644 --- a/tests/test_global_observation.py +++ b/tests/test_global_observation.py @@ -1,5 +1,6 @@ import numpy as np +from flatland.envs.agent_utils import EnvAgent, RailAgentStatus from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import sparse_rail_generator @@ -23,42 +24,106 @@ def test_get_global_observation(): env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=25, - # Number of cities in map (where train stations are) - num_intersections=10, - # Number of intersections (no start / target) - num_trainstations=50, # Number of possible start/targets on map - min_node_dist=3, # Minimal distance of nodes - node_radius=4, # Proximity of stations to city center - num_neighb=4, - # Number of connections to other cities/intersections - seed=15, # Random seed - grid_mode=True, - enhance_intersection=False + rail_generator=sparse_rail_generator(max_num_cities=6, + max_rails_between_cities=4, + seed=15, + grid_mode=False ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=number_of_agents, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=GlobalObsForRailEnv()) obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)}) - for i in range(len(env.agents)): + agent: EnvAgent = env.agents[i] + print("[{}] status={}, position={}, target={}, initial_position={}".format(i, agent.status, agent.position, + agent.target, + agent.initial_position)) + + for i, agent in enumerate(env.agents): obs_agents_state = obs[i][1] obs_targets = obs[i][2] + # test first channel of obs_targets: own target nr_agents = np.count_nonzero(obs_targets[:, :, 0]) - nr_agents_other = np.count_nonzero(obs_targets[:, :, 1]) - assert nr_agents == 1 - assert nr_agents_other == (number_of_agents - 1) + assert nr_agents == 1, "agent {}: something wrong with own target, found {}".format(i, nr_agents) + + # test second channel of obs_targets: other agent's target + for r in range(env.height): + for c in range(env.width): + _other_agent_target = 0 + for other_i, other_agent in enumerate(env.agents): + if other_agent.target == (r, c): + _other_agent_target = 1 + break + assert obs_targets[(r, c)][ + 1] == _other_agent_target, "agent {}: at {} expected to be other agent's target = {}".format( + i, (r, c), + _other_agent_target) + + # test first channel of obs_agents_state: direction at own position + for r in range(env.height): + for c in range(env.width): + if (agent.status == RailAgentStatus.ACTIVE or agent.status == RailAgentStatus.DONE) and ( + r, c) == agent.position: + assert np.isclose(obs_agents_state[(r, c)][0], agent.direction), \ + "agent {} in status {} at {} expected to contain own direction {}, found {}" \ + .format(i, agent.status, (r, c), agent.direction, obs_agents_state[(r, c)][0]) + elif (agent.status == RailAgentStatus.READY_TO_DEPART) and (r, c) == agent.initial_position: + assert np.isclose(obs_agents_state[(r, c)][0], agent.direction), \ + "agent {} in status {} at {} expected to contain own direction {}, found {}" \ + .format(i, agent.status, (r, c), agent.direction, obs_agents_state[(r, c)][0]) + else: + assert np.isclose(obs_agents_state[(r, c)][0], -1), \ + "agent {} in status {} at {} expected contain -1 found {}" \ + .format(i, agent.status, (r, c), obs_agents_state[(r, c)][0]) + + # test second channel of obs_agents_state: direction at other agents position + for r in range(env.height): + for c in range(env.width): + has_agent = False + for other_i, other_agent in enumerate(env.agents): + if i == other_i: + continue + if other_agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and ( + r, c) == other_agent.position: + assert np.isclose(obs_agents_state[(r, c)][1], other_agent.direction), \ + "agent {} in status {} at {} should see other agent with direction {}, found = {}" \ + .format(i, agent.status, (r, c), other_agent.direction, obs_agents_state[(r, c)][1]) + has_agent = True + if not has_agent: + assert np.isclose(obs_agents_state[(r, c)][1], -1), \ + "agent {} in status {} at {} should see no other agent direction (-1), found = {}" \ + .format(i, agent.status, (r, c), obs_agents_state[(r, c)][1]) - # since the array is initialized with -1 add one in order to used np.count_nonzero - obs_agents_state += 1 - obs_agents_state_0 = np.count_nonzero(obs_agents_state[:, :, 0]) - obs_agents_state_1 = np.count_nonzero(obs_agents_state[:, :, 1]) - obs_agents_state_2 = np.count_nonzero(obs_agents_state[:, :, 2]) - obs_agents_state_3 = np.count_nonzero(obs_agents_state[:, :, 3]) - assert obs_agents_state_0 == 1 - assert obs_agents_state_1 == (number_of_agents - 1) - assert obs_agents_state_2 == number_of_agents - assert obs_agents_state_3 == number_of_agents + # test third and fourth channel of obs_agents_state: malfunction and speed of own or other agent in the grid + for r in range(env.height): + for c in range(env.width): + has_agent = False + for other_i, other_agent in enumerate(env.agents): + if other_agent.status in [RailAgentStatus.ACTIVE, + RailAgentStatus.DONE] and other_agent.position == (r, c): + assert np.isclose(obs_agents_state[(r, c)][2], other_agent.malfunction_data['malfunction']), \ + "agent {} in status {} at {} should see agent malfunction {}, found = {}" \ + .format(i, agent.status, (r, c), other_agent.malfunction_data['malfunction'], + obs_agents_state[(r, c)][2]) + assert np.isclose(obs_agents_state[(r, c)][3], other_agent.speed_data['speed']) + has_agent = True + if not has_agent: + assert np.isclose(obs_agents_state[(r, c)][2], -1), \ + "agent {} in status {} at {} should see no agent malfunction (-1), found = {}" \ + .format(i, agent.status, (r, c), obs_agents_state[(r, c)][2]) + assert np.isclose(obs_agents_state[(r, c)][3], -1), \ + "agent {} in status {} at {} should see no agent speed (-1), found = {}" \ + .format(i, agent.status, (r, c), obs_agents_state[(r, c)][3]) + # test fifth channel of obs_agents_state: number of agents ready to depart in to this cell + for r in range(env.height): + for c in range(env.width): + count = 0 + for other_i, other_agent in enumerate(env.agents): + if other_agent.status == RailAgentStatus.READY_TO_DEPART and other_agent.initial_position == (r, c): + count += 1 + assert np.isclose(obs_agents_state[(r, c)][4], count), \ + "agent {} in status {} at {} should see {} agents ready to depart, found{}" \ + .format(i, agent.status, (r, c), count, obs_agents_state[(r, c)][4])