diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..79f5eb0d23b8b8a50bc1afd9ba1d34bba76c4ffd --- /dev/null +++ b/flatland/envs/env_utils.py @@ -0,0 +1,274 @@ + +""" +Definition of the RailEnv environment and related level-generation functions. + +Generator functions are functions that take width, height and num_resets as arguments and return +a GridTransitionMap object. +""" +# import numpy as np + +# from flatland.core.env import Environment +# from flatland.core.env_observation_builder import TreeObsForRailEnv + +# from flatland.core.transitions import Grid8Transitions, RailEnvTransitions +# from flatland.core.transition_map import GridTransitionMap + + +class AStarNode(): + """A node class for A* Pathfinding""" + + def __init__(self, parent=None, pos=None): + self.parent = parent + self.pos = pos + self.g = 0 + self.h = 0 + self.f = 0 + + def __eq__(self, other): + return self.pos == other.pos + + def update_if_better(self, other): + if other.g < self.g: + self.parent = other.parent + self.g = other.g + self.h = other.h + self.f = other.f + + +def get_direction(pos1, pos2): + """ + Assumes pos1 and pos2 are adjacent location on grid. + Returns direction (int) that can be used with transitions. + """ + diff_0 = pos2[0] - pos1[0] + diff_1 = pos2[1] - pos1[1] + if diff_0 < 0: + return 0 + if diff_0 > 0: + return 2 + if diff_1 > 0: + return 1 + if diff_1 < 0: + return 3 + return 0 + + +def mirror(dir): + return (dir + 2) % 4 + + +def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos): + # start by getting direction used to get to current node + # and direction from current node to possible child node + new_dir = get_direction(current_pos, new_pos) + if prev_pos is not None: + current_dir = get_direction(prev_pos, current_pos) + else: + current_dir = new_dir + # create new transition that would go to child + new_trans = rail_array[current_pos] + if prev_pos is None: + if new_trans == 0: + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + else: + # check if matches existing layout + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + # rail_trans.print(new_trans) + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + if new_pos == end_pos: + # need to validate end pos setup as well + new_trans_e = rail_array[end_pos] + if new_trans_e == 0: + # need to flip direction because of how end points are defined + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + else: + # check if matches existing layout + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + # new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1) + # print("end:", end_pos, current_pos) + # rail_trans.print(new_trans_e) + + # print("========> end trans") + # rail_trans.print(new_trans_e) + if not rail_trans.is_valid(new_trans_e): + # print("end failed", end_pos, current_pos) + return False + # else: + # print("end ok!", end_pos, current_pos) + + # is transition is valid? + # print("=======> trans") + # rail_trans.print(new_trans) + return rail_trans.is_valid(new_trans) + + +def a_star(rail_trans, rail_array, start, end): + """ + Returns a list of tuples as a path from the given start to end. + If no path is found, returns path to closest point to end. + """ + rail_shape = rail_array.shape + start_node = AStarNode(None, start) + end_node = AStarNode(None, end) + open_list = [] + closed_list = [] + + open_list.append(start_node) + + # this could be optimized + def is_node_in_list(node, the_list): + for o_node in the_list: + if node == o_node: + return o_node + return None + + while len(open_list) > 0: + # get node with current shortest est. path (lowest f) + current_node = open_list[0] + current_index = 0 + for index, item in enumerate(open_list): + if item.f < current_node.f: + current_node = item + current_index = index + + # pop current off open list, add to closed list + open_list.pop(current_index) + closed_list.append(current_node) + + # print("a*:", current_node.pos) + # for cn in closed_list: + # print("closed:", cn.pos) + + # found the goal + if current_node == end_node: + path = [] + current = current_node + while current is not None: + path.append(current.pos) + current = current.parent + # return reversed path + return path[::-1] + + # generate children + children = [] + if current_node.parent is not None: + prev_pos = current_node.parent.pos + else: + prev_pos = None + for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]: + node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1]) + if node_pos[0] >= rail_shape[0] or \ + node_pos[0] < 0 or \ + node_pos[1] >= rail_shape[1] or \ + node_pos[1] < 0: + continue + + # validate positions + # debug: avoid all current rails + # if rail_array.item(node_pos) != 0: + # continue + + # validate positions + if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos): + # print("A*: transition invalid") + continue + + # create new node + new_node = AStarNode(current_node, node_pos) + children.append(new_node) + + # loop through children + for child in children: + # already in closed list? + closed_node = is_node_in_list(child, closed_list) + if closed_node is not None: + continue + + # create the f, g, and h values + child.g = current_node.g + 1 + # this heuristic favors diagonal paths + # child.h = ((child.pos[0] - end_node.pos[0]) ** 2) + \ + # ((child.pos[1] - end_node.pos[1]) ** 2) + # this heuristic avoids diagonal paths + child.h = abs(child.pos[0] - end_node.pos[0]) + abs(child.pos[1] - end_node.pos[1]) + child.f = child.g + child.h + + # already in the open list? + open_node = is_node_in_list(child, open_list) + if open_node is not None: + open_node.update_if_better(child) + continue + + # add the child to the open list + open_list.append(child) + + # no full path found, return partial path + if len(open_list) == 0: + path = [] + current = current_node + while current is not None: + path.append(current.pos) + current = current.parent + # return reversed path + print("partial:", start, end, path[::-1]) + return path[::-1] + + +def connect_rail(rail_trans, rail_array, start, end): + """ + Creates a new path [start,end] in rail_array, based on rail_trans. + """ + # in the worst case we will need to do a A* search, so we might as well set that up + path = a_star(rail_trans, rail_array, start, end) + # print("connecting path", path) + if len(path) < 2: + return + current_dir = get_direction(path[0], path[1]) + end_pos = path[-1] + for index in range(len(path) - 1): + current_pos = path[index] + new_pos = path[index + 1] + new_dir = get_direction(current_pos, new_pos) + + new_trans = rail_array[current_pos] + if index == 0: + if new_trans == 0: + # end-point + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + else: + # into existing rail + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + pass + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + rail_array[current_pos] = new_trans + + if new_pos == end_pos: + # setup end pos setup + new_trans_e = rail_array[end_pos] + if new_trans_e == 0: + # end-point + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + else: + # into existing rail + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + # new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1) + rail_array[end_pos] = new_trans_e + + current_dir = new_dir + + +def distance_on_rail(pos1, pos2): + return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1]) + diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py new file mode 100644 index 0000000000000000000000000000000000000000..fed790173941d7afc5ea5e3956bd131443f3ef57 --- /dev/null +++ b/flatland/envs/generators.py @@ -0,0 +1,478 @@ +import numpy as np + +# from flatland.core.env import Environment +# from flatland.core.env_observation_builder import TreeObsForRailEnv + +from flatland.core.transitions import Grid8Transitions, RailEnvTransitions +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.env_utils import distance_on_rail, connect_rail + + +def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): + """ + Parameters + ------- + width : int + The width (number of cells) of the grid to generate. + height : int + The height (number of cells) of the grid to generate. + + Returns + ------- + numpy.ndarray of type numpy.uint16 + The matrix with the correct 16-bit bitmaps for each cell. + """ + + def generator(width, height, num_resets=0): + rail_trans = RailEnvTransitions() + rail_array = np.zeros(shape=(width, height), dtype=np.uint16) + + np.random.seed(seed + num_resets) + + # generate rail array + # step 1: + # - generate a list of start and goal positions + # - use a min/max distance allowed as input for this + # - validate that start/goals are not placed too close to other start/goals + # + # step 2: (optional) + # - place random elements on rails array + # - for instance "train station", etc. + # + # step 3: + # - iterate over all [start, goal] pairs: + # - [first X pairs] + # - draw a rail from [start,goal] + # - draw either vertical or horizontal part first (randomly) + # - if rail crosses existing rail then validate new connection + # - if new connection is invalid turn 90 degrees to left/right + # - possibility that this fails to create a path to goal + # - on failure goto step1 and retry with seed+1 + # - [avoid crossing other start,goal positions] (optional) + # + # - [after X pairs] + # - find closest rail from start (Pa) + # - iterating outwards in a "circle" from start until an existing rail cell is hit + # - connect [start, Pa] + # - validate crossing rails + # - Do A* from Pa to find closest point on rail (Pb) to goal point + # - Basically normal A* but find point on rail which is closest to goal + # - since full path to goal is unlikely + # - connect [Pb, goal] + # - validate crossing rails + # + # step 4: (optional) + # - add more rails to map randomly + # + # step 5: + # - return transition map + list of [start, goal] points + # + + start_goal = [] + for _ in range(nr_start_goal): + sanity_max = 9000 + for _ in range(sanity_max): + start = (np.random.randint(0, width), np.random.randint(0, height)) + goal = (np.random.randint(0, height), np.random.randint(0, height)) + # check to make sure start,goal pos is empty? + if rail_array[goal] != 0 or rail_array[start] != 0: + continue + # check min/max distance + dist_sg = distance_on_rail(start, goal) + if dist_sg < min_dist: + continue + if dist_sg > max_dist: + continue + # check distance to existing points + sg_new = [start, goal] + + def check_all_dist(sg_new): + for sg in start_goal: + for i in range(2): + for j in range(2): + dist = distance_on_rail(sg_new[i], sg[j]) + if dist < 2: + # print("too close:", dist, sg_new[i], sg[j]) + return False + return True + + if check_all_dist(sg_new): + break + start_goal.append([start, goal]) + connect_rail(rail_trans, rail_array, start, goal) + + print("Created #", len(start_goal), "pairs") + # print(start_goal) + + return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans) + return_rail.grid = rail_array + # TODO: return start_goal + return return_rail + + return generator + + +def rail_from_manual_specifications_generator(rail_spec): + """ + Utility to convert a rail given by manual specification as a map of tuples + (cell_type, rotation), to a transition map with the correct 16-bit + transitions specifications. + + Parameters + ------- + rail_spec : list of list of tuples + List (rows) of lists (columns) of tuples, each specifying a cell for + the RailEnv environment as (cell_type, rotation), with rotation being + clock-wise and in [0, 90, 180, 270]. + + Returns + ------- + function + Generator function that always returns a GridTransitionMap object with + the matrix of correct 16-bit bitmaps for each cell. + """ + + def generator(width, height, num_resets=0): + t_utils = RailEnvTransitions() + + height = len(rail_spec) + width = len(rail_spec[0]) + rail = GridTransitionMap(width=width, height=height, transitions=t_utils) + + for r in range(height): + for c in range(width): + cell = rail_spec[r][c] + if cell[0] < 0 or cell[0] >= len(t_utils.transitions): + print("ERROR - invalid cell type=", cell[0]) + return [] + rail.set_transitions((r, c), t_utils.rotate_transition(t_utils.transitions[cell[0]], cell[1])) + + return rail + + return generator + + +def rail_from_GridTransitionMap_generator(rail_map): + """ + Utility to convert a rail given by a GridTransitionMap map with the correct + 16-bit transitions specifications. + + Parameters + ------- + rail_map : GridTransitionMap object + GridTransitionMap object to return when the generator is called. + + Returns + ------- + function + Generator function that always returns the given `rail_map' object. + """ + + def generator(width, height, num_resets=0): + return rail_map + + return generator + + +def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames): + """ + Utility to sequentially and cyclically return GridTransitionMap-s from a list of files, on each environment reset. + + Parameters + ------- + list_of_filenames : list + List of filenames with the saved grids to load. + + Returns + ------- + function + Generator function that always returns the given `rail_map' object. + """ + + def generator(width, height, num_resets=0): + t_utils = RailEnvTransitions() + rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils) + rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False) + + if rail_map.grid.dtype == np.uint64: + rail_map.transitions = Grid8Transitions() + + return rail_map + + return generator + + +""" +def generate_rail_from_list_of_manual_specifications(list_of_specifications) + def generator(width, height, num_resets=0): + return generate_rail_from_manual_specifications(list_of_specifications) + + return generator +""" + + +def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): + """ + Dummy random level generator: + - fill in cells at random in [width-2, height-2] + - keep filling cells in among the unfilled ones, such that all transitions + are legit; if no cell can be filled in without violating some + transitions, pick one among those that can satisfy most transitions + (1,2,3 or 4), and delete (+mark to be re-filled) the cells that were + incompatible. + - keep trying for a total number of insertions + (e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the + board and try again from scratch. + - finally pad the border of the map with dead-ends to avoid border issues. + + Dead-ends are not allowed inside the grid, only at the border; however, if + no cell type can be inserted in a given cell (because of the neighboring + transitions), deadends are allowed if they solve the problem. This was + found to turn most un-genereatable levels into valid ones. + + Parameters + ------- + width : int + The width (number of cells) of the grid to generate. + height : int + The height (number of cells) of the grid to generate. + + Returns + ------- + numpy.ndarray of type numpy.uint16 + The matrix with the correct 16-bit bitmaps for each cell. + """ + + def generator(width, height, num_resets=0): + t_utils = RailEnvTransitions() + + transition_probability = cell_type_relative_proportion + + transitions_templates_ = [] + transition_probabilities = [] + for i in range(len(t_utils.transitions) - 4): # don't include dead-ends + all_transitions = 0 + for dir_ in range(4): + trans = t_utils.get_transitions(t_utils.transitions[i], dir_) + all_transitions |= (trans[0] << 3) | \ + (trans[1] << 2) | \ + (trans[2] << 1) | \ + (trans[3]) + + template = [int(x) for x in bin(all_transitions)[2:]] + template = [0] * (4 - len(template)) + template + + # add all rotations + for rot in [0, 90, 180, 270]: + transitions_templates_.append((template, + t_utils.rotate_transition( + t_utils.transitions[i], + rot))) + transition_probabilities.append(transition_probability[i]) + template = [template[-1]] + template[:-1] + + def get_matching_templates(template): + ret = [] + for i in range(len(transitions_templates_)): + is_match = True + for j in range(4): + if template[j] >= 0 and template[j] != transitions_templates_[i][0][j]: + is_match = False + break + if is_match: + ret.append((transitions_templates_[i][1], transition_probabilities[i])) + return ret + + MAX_INSERTIONS = (width - 2) * (height - 2) * 10 + MAX_ATTEMPTS_FROM_SCRATCH = 10 + + attempt_number = 0 + while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH: + cells_to_fill = [] + rail = [] + for r in range(height): + rail.append([None] * width) + if r > 0 and r < height - 1: + cells_to_fill = cells_to_fill + [(r, c) for c in range(1, width - 1)] + + num_insertions = 0 + while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0: + # cell = random.sample(cells_to_fill, 1)[0] + cell = cells_to_fill[np.random.choice(len(cells_to_fill), 1)[0]] + cells_to_fill.remove(cell) + row = cell[0] + col = cell[1] + + # look at its neighbors and see what are the possible transitions + # that can be chosen from, if any. + valid_template = [-1, -1, -1, -1] + + for el in [(0, 2, (-1, 0)), + (1, 3, (0, 1)), + (2, 0, (1, 0)), + (3, 1, (0, -1))]: # N, E, S, W + neigh_trans = rail[row + el[2][0]][col + el[2][1]] + if neigh_trans is not None: + # select transition coming from facing direction el[1] and + # moving to direction el[1] + max_bit = 0 + for k in range(4): + max_bit |= t_utils.get_transition(neigh_trans, k, el[1]) + + if max_bit: + valid_template[el[0]] = 1 + else: + valid_template[el[0]] = 0 + + possible_cell_transitions = get_matching_templates(valid_template) + + if len(possible_cell_transitions) == 0: # NO VALID TRANSITIONS + # no cell can be filled in without violating some transitions + # can a dead-end solve the problem? + if valid_template.count(1) == 1: + for k in range(4): + if valid_template[k] == 1: + rot = 0 + if k == 0: + rot = 180 + elif k == 1: + rot = 270 + elif k == 2: + rot = 0 + elif k == 3: + rot = 90 + + rail[row][col] = t_utils.rotate_transition(int('0010000000000000', 2), rot) + num_insertions += 1 + + break + + else: + # can I get valid transitions by removing a single + # neighboring cell? + bestk = -1 + besttrans = [] + for k in range(4): + tmp_template = valid_template[:] + tmp_template[k] = -1 + possible_cell_transitions = get_matching_templates(tmp_template) + if len(possible_cell_transitions) > len(besttrans): + besttrans = possible_cell_transitions + bestk = k + + if bestk >= 0: + # Replace the corresponding cell with None, append it + # to cells to fill, fill in a transition in the current + # cell. + replace_row = row - 1 + replace_col = col + if bestk == 1: + replace_row = row + replace_col = col + 1 + elif bestk == 2: + replace_row = row + 1 + replace_col = col + elif bestk == 3: + replace_row = row + replace_col = col - 1 + + cells_to_fill.append((replace_row, replace_col)) + rail[replace_row][replace_col] = None + + possible_transitions, possible_probabilities = zip(*besttrans) + possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities] + + rail[row][col] = np.random.choice(possible_transitions, + p=possible_probabilities) + num_insertions += 1 + + else: + print('WARNING: still nothing!') + rail[row][col] = int('0000000000000000', 2) + num_insertions += 1 + pass + + else: + possible_transitions, possible_probabilities = zip(*possible_cell_transitions) + possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities] + + rail[row][col] = np.random.choice(possible_transitions, + p=possible_probabilities) + num_insertions += 1 + + if num_insertions == MAX_INSERTIONS: + # Failed to generate a valid level; try again for a number of times + attempt_number += 1 + else: + break + + if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH: + print('ERROR: failed to generate level') + + # Finally pad the border of the map with dead-ends to avoid border issues; + # at most 1 transition in the neigh cell + for r in range(height): + # Check for transitions coming from [r][1] to WEST + max_bit = 0 + neigh_trans = rail[r][1] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) + max_bit = max_bit | (neigh_trans_from_direction & 1) + if max_bit: + rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270) + else: + rail[r][0] = int('0000000000000000', 2) + + # Check for transitions coming from [r][-2] to EAST + max_bit = 0 + neigh_trans = rail[r][-2] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) + max_bit = max_bit | (neigh_trans_from_direction & (1 << 2)) + if max_bit: + rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2), + 90) + else: + rail[r][-1] = int('0000000000000000', 2) + + for c in range(width): + # Check for transitions coming from [1][c] to NORTH + max_bit = 0 + neigh_trans = rail[1][c] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) + max_bit = max_bit | (neigh_trans_from_direction & (1 << 3)) + if max_bit: + rail[0][c] = int('0010000000000000', 2) + else: + rail[0][c] = int('0000000000000000', 2) + + # Check for transitions coming from [-2][c] to SOUTH + max_bit = 0 + neigh_trans = rail[-2][c] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) + max_bit = max_bit | (neigh_trans_from_direction & (1 << 1)) + if max_bit: + rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180) + else: + rail[-1][c] = int('0000000000000000', 2) + + # For display only, wrong levels + for r in range(height): + for c in range(width): + if rail[r][c] is None: + rail[r][c] = int('0000000000000000', 2) + + tmp_rail = np.asarray(rail, dtype=np.uint16) + + return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) + return_rail.grid = tmp_rail + return return_rail + + return generator +