def is_valid(self, cell_transition):
    """
    Check whether a cell transition value is one of the known cell setups.

    The value is accepted when it equals an entry of ``self.transitions``
    either as stored or after one, two or three successive 90-degree
    rotations of that entry.

    Parameters
    ----------
    cell_transition : int
        Encoded set of transitions of a single cell.

    Returns
    -------
    bool
        True if a (possibly rotated) entry of ``self.transitions`` matches,
        False otherwise.
    """
    for reference in self.transitions:
        candidate = reference
        # Four orientations per reference: as stored, then three quarter turns.
        for quarter_turns in range(4):
            if quarter_turns > 0:
                candidate = self.rotate_transition(candidate, rotation=90)
            if cell_transition == candidate:
                return True
    return False
class AStarNode():
    """
    Search node used by :func:`a_star`.

    Attributes
    ----------
    parent : AStarNode or None
        Node this one was reached from (None for the start node).
    pos : tuple
        Grid position of the node.
    g : int
        Cost of the path from the start node to this node.
    h : int
        Heuristic estimate of the remaining cost to the goal.
    f : int
        Total estimate, ``g + h``.
    """

    def __init__(self, parent=None, pos=None):
        self.parent = parent
        self.pos = pos
        self.g = 0
        self.h = 0
        self.f = 0

    def __eq__(self, other):
        # Nodes are identified by their position only.
        if not isinstance(other, AStarNode):
            return NotImplemented
        return self.pos == other.pos

    def __hash__(self):
        # Defined together with __eq__ so nodes stay usable in sets/dicts.
        return hash(self.pos)

    def update_if_better(self, other):
        """Adopt the parent and costs of ``other`` if it has a lower ``g``."""
        if other.g < self.g:
            self.parent = other.parent
            self.g = other.g
            self.h = other.h
            self.f = other.f


def a_star(rail_array, start, end):
    """
    Find a 4-connected path on the grid from ``start`` to ``end``.

    Only the shape of ``rail_array`` is used: every in-bounds cell is
    walkable and each step costs 1. The heuristic is the Manhattan distance
    to ``end``.

    Parameters
    ----------
    rail_array : array-like with a ``shape`` attribute
        Grid whose first two dimensions bound the search area.
    start : tuple of int
        Position the path starts from.
    end : tuple of int
        Position the path should reach.

    Returns
    -------
    list of tuple
        Positions from ``start`` to ``end`` inclusive. If ``end`` cannot be
        reached (e.g. it lies outside the grid), the path from ``start`` to
        the expanded position with the smallest Manhattan distance to
        ``end`` is returned instead.
    """
    n_rows = rail_array.shape[0]
    n_cols = rail_array.shape[1]
    start = tuple(start)
    end = tuple(end)

    def manhattan(pos):
        # This heuristic avoids diagonal (staircase) paths.
        return abs(pos[0] - end[0]) + abs(pos[1] - end[1])

    def backtrack(node):
        # Follow parent links back to the start, then reverse.
        path = []
        while node is not None:
            path.append(node.pos)
            node = node.parent
        return path[::-1]

    start_node = AStarNode(None, start)
    # Give the start node a real heuristic so it can be compared when
    # looking for the closest reachable position.
    start_node.h = manhattan(start)
    start_node.f = start_node.h

    # dicts keep insertion order, so iterating the open nodes visits them in
    # the order they were discovered, which fixes how ties on f are broken.
    open_nodes = {start: start_node}
    closed = set()
    closest = start_node

    while open_nodes:
        # min() returns the first node with the lowest f.
        current = min(open_nodes.values(), key=lambda node: node.f)
        del open_nodes[current.pos]
        closed.add(current.pos)

        # Found the goal.
        if current.pos == end:
            return backtrack(current)

        # Remember the best fallback target in case the goal is unreachable.
        if current.h < closest.h:
            closest = current

        for d_row, d_col in ((0, -1), (0, 1), (-1, 0), (1, 0)):
            pos = (current.pos[0] + d_row, current.pos[1] + d_col)
            if not (0 <= pos[0] < n_rows and 0 <= pos[1] < n_cols):
                continue
            if pos in closed:
                continue

            child = AStarNode(current, pos)
            child.g = current.g + 1
            child.h = manhattan(pos)
            child.f = child.g + child.h

            known = open_nodes.get(pos)
            if known is not None:
                known.update_if_better(child)
            else:
                open_nodes[pos] = child

    # No full path found: return the partial path to the closest position.
    return backtrack(closest)
def complex_rail_generator(nr_start_goal=10, min_dist=0, max_dist=99999, seed=0):
    """
    Build a rail generator that links random start/goal pairs with A* paths.

    Parameters
    ----------
    nr_start_goal : int
        Number of random [start, goal] position pairs to connect.
    min_dist : int
        Intended minimum distance between a start and its goal.
        Currently unused (see the TODOs in the generator).
    max_dist : int
        Intended maximum distance between a start and its goal.
        Currently unused (see the TODOs in the generator).
    seed : int
        Value passed to ``np.random.seed`` every time the generator runs,
        so repeated calls with the same size produce the same layout.

    Returns
    -------
    function
        ``generator(width, height, num_resets=0)`` which returns a
        GridTransitionMap whose ``grid`` is a ``(width, height)`` array of
        type numpy.uint16 holding the transition value of each cell.
    """

    def generator(width, height, num_resets=0):
        rail_trans = RailEnvTransitions()
        rail_array = np.zeros(shape=(width, height), dtype=np.uint16)

        np.random.seed(seed)

        # generate rail array
        # step 1:
        # - generate a list of start and goal positions
        # - use a min/max distance allowed as input for this
        # - validate that start/goals are not placed too close to other start/goals
        #
        # step 2: (optional)
        # - place random elements on rails array
        #   - for instance "train station", etc.
        #
        # step 3:
        # - iterate over all [start, goal] pairs:
        #   - [first X pairs]
        #     - draw a rail from [start,goal]
        #     - draw either vertical or horizontal part first (randomly)
        #     - if rail crosses existing rail then validate new connection
        #       - if new connection is invalid turn 90 degrees to left/right
        #       - possibility that this fails to create a path to goal
        #         - on failure goto step1 and retry with seed+1
        #     - [avoid crossing other start,goal positions] (optional)
        #
        #   - [after X pairs]
        #     - find closest rail from start (Pa)
        #       - iterating outwards in a "circle" from start until an existing rail cell is hit
        #     - connect [start, Pa]
        #       - validate crossing rails
        #     - Do A* from Pa to find closest point on rail (Pb) to goal point
        #       - Basically normal A* but find point on rail which is closest to goal
        #       - since full path to goal is unlikely
        #     - connect [Pb, goal]
        #       - validate crossing rails
        #
        # step 4: (optional)
        # - add more rails to map randomly
        #
        # step 5:
        # - return transition map + list of [start, goal] points
        #

        start_goal = []
        for _ in range(nr_start_goal):
            start = (np.random.randint(0, width), np.random.randint(0, height))
            # The first index spans ``width`` (rail_array has shape
            # (width, height)), exactly as for ``start``.
            goal = (np.random.randint(0, width), np.random.randint(0, height))
            # TODO: validate closeness with existing points
            # TODO: make sure min/max distance condition is met
            start_goal.append([start, goal])

        def get_direction(pos1, pos2):
            # Direction code of the step pos1 -> pos2:
            #   0: first index decreases, 2: first index increases,
            #   1: second index increases, 3: second index decreases.
            # Identical positions yield 0.
            diff_0 = pos2[0] - pos1[0]
            diff_1 = pos2[1] - pos1[1]
            if diff_0 < 0:
                return 0
            if diff_0 > 0:
                return 2
            if diff_1 > 0:
                return 1
            if diff_1 < 0:
                return 3
            return 0

        def connect_two_cells(pos1, pos2):
            # connect two adjacent cells
            direction = get_direction(pos1, pos2)
            rail_array[pos1] = rail_trans.set_transition(rail_array[pos1], direction, direction, 1)
            o_dir = (direction + 2) % 4
            rail_array[pos2] = rail_trans.set_transition(rail_array[pos2], o_dir, o_dir, 1)

        def connect_rail(start, end):
            # in the worst case we will need to do a A* search, so we might as well set that up
            # TODO: need to check transitions in A* to see if new path is valid
            path = a_star(rail_array, start, end)
            if len(path) < 2:
                return
            if len(path) == 2:
                connect_two_cells(path[0], path[1])
                return

            last_index = len(path) - 1
            current_dir = get_direction(path[0], path[1])
            for index, pos in enumerate(path):
                # The final cell has no successor, so it keeps the incoming direction.
                if index < last_index:
                    new_dir = get_direction(pos, path[index + 1])
                else:
                    new_dir = current_dir

                cell_trans = rail_array[pos]
                if index != last_index:
                    # set the forward path
                    cell_trans = rail_trans.set_transition(cell_trans, current_dir, new_dir, 1)
                if index != 0:
                    # set the backwards path
                    cell_trans = rail_trans.set_transition(
                        cell_trans, (new_dir + 2) % 4, (current_dir + 2) % 4, 1)
                rail_array[pos] = cell_trans
                current_dir = new_dir

        for start, goal in start_goal:
            connect_rail(start, goal)

        return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans)
        return_rail.grid = rail_array
        return return_rail

    return generator