diff --git a/examples/play_model.py b/examples/play_model.py index 2c18c3e3fbf5e54320f3382ae158f542a2130080..8f0df7cdd7957c432515d04503281d45e3bbdbc2 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -1,4 +1,4 @@ -from flatland.envs.rail_env import RailEnv, random_rail_generator +from flatland.envs.rail_env import RailEnv, random_rail_generator, complex_rail_generator # from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.utils.rendertools import RenderTool from flatland.baselines.dueling_double_dqn import Agent @@ -6,7 +6,7 @@ from collections import deque import torch import random import numpy as np -import matplotlib.pyplot as plt +#import matplotlib.pyplot as plt import time @@ -94,24 +94,23 @@ def main(render=True, delay=0.0): # Example generate a rail given a manual specification, # a map of tuples (cell_type, rotation) - transition_probability = [0.5, # empty cell - Case 0 - 1.0, # Case 1 - straight - 1.0, # Case 2 - simple switch - 0.3, # Case 3 - diamond drossing - 0.5, # Case 4 - single slip - 0.5, # Case 5 - double slip - 0.2, # Case 6 - symmetrical - 0.0] # Case 7 - dead end + #transition_probability = [0.5, # empty cell - Case 0 + # 1.0, # Case 1 - straight + # 1.0, # Case 2 - simple switch + # 0.3, # Case 3 - diamond crossing + # 0.5, # Case 4 - single slip + # 0.5, # Case 5 - double slip + # 0.2, # Case 6 - symmetrical + # 0.0] # Case 7 - dead end # Example generate a random rail - env = RailEnv(width=15, - height=15, - rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), - number_of_agents=5) + env = RailEnv(width=15, height=15, + rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=5), + number_of_agents=1) if render: env_renderer = RenderTool(env, gl="QT") - plt.figure(figsize=(5,5)) + # plt.figure(figsize=(5,5)) # fRedis = redis.Redis() handle = env.get_agent_handles() diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index 
def __init__(self):
    """Initialise the base grid transitions with the cell templates in ``transition_list``."""
    super(RailEnvTransitions, self).__init__(
        transitions=self.transition_list
    )


def print(self, cell_transition):
    """
    Print a cell's transition bitmap as a 4x4 matrix, for debugging.

    The 16-bit value is split into four 4-bit nibbles, printed from the most
    significant to the least significant one and labelled N, E, S, W; the
    header row labels the bits inside each nibble N, E, S, W as well.

    Parameters
    ----------
    cell_transition : int
        16-bit transition bitmap of a single cell.
    """
    # Two leading spaces so the column header lines up with the "N 0101" rows.
    print("  NESW")
    for row, label in enumerate("NESW"):
        shift = (3 - row) * 4
        print(label, format(cell_transition >> shift & 0xF, '04b'))


def is_valid(self, cell_transition):
    """
    Check whether a cell transition matches a known cell template.

    The value is accepted when it equals one of ``self.transitions`` or one
    of the three successive 90-degree rotations of such a template.

    Parameters
    ----------
    cell_transition : int
        16 bits (four 4-bit nibbles) encoding the valid transitions of a cell.

    Returns
    -------
    bool
        True if the bitmap equals a template in any of its four orientations,
        False otherwise.
    """
    for template in self.transitions:
        candidate = template
        if cell_transition == candidate:
            return True
        # Rotate three more times to cover the remaining orientations.
        for _ in range(3):
            candidate = self.rotate_transition(candidate, rotation=90)
            if cell_transition == candidate:
                return True
    return False


class AStarNode:
    """
    A node class for A* pathfinding.

    Two nodes compare equal when they sit on the same grid position,
    whatever their parents or costs are.
    """

    def __init__(self, parent=None, pos=None):
        self.parent = parent  # node this one was reached from (None for a root)
        self.pos = pos        # grid position of the node
        self.g = 0            # cost accumulated from the start node
        self.h = 0            # heuristic estimate of the remaining cost
        self.f = 0            # g + h, the value the search orders nodes by

    def __eq__(self, other):
        if not isinstance(other, AStarNode):
            # Let Python fall back to its default instead of raising
            # AttributeError on a missing ``pos``.
            return NotImplemented
        return self.pos == other.pos

    def __hash__(self):
        # Kept consistent with __eq__ so nodes can live in sets and dicts.
        return hash(self.pos)

    def __repr__(self):
        return "{}(pos={!r}, g={!r}, h={!r}, f={!r})".format(
            type(self).__name__, self.pos, self.g, self.h, self.f)

    def update_if_better(self, other):
        """Adopt the parent and costs of `other` when it has a strictly lower g."""
        if other.g < self.g:
            self.parent = other.parent
            self.g = other.g
            self.h = other.h
            self.f = other.f
def get_direction(pos1, pos2):
    """
    Return the direction of the step from `pos1` to `pos2`.

    Assumes pos1 and pos2 are adjacent locations on the grid.

    Returns
    -------
    int
        Direction usable with transitions: 0 when the first coordinate
        decreases, 2 when it increases, 1 when the second coordinate
        increases, 3 when it decreases. Identical positions yield 0.
    """
    d_row = pos2[0] - pos1[0]
    d_col = pos2[1] - pos1[1]
    if d_row < 0:
        return 0
    if d_row > 0:
        return 2
    if d_col > 0:
        return 1
    if d_col < 0:
        return 3
    return 0


def mirror(dir):
    """Return the direction opposite to `dir` (0 <-> 2, 1 <-> 3)."""
    return (dir + 2) % 4


def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos):
    """
    Check whether a path may continue from `current_pos` to `new_pos`.

    The transition needed for the move is added to a copy of the cell value
    at `current_pos` (the array itself is not written) and the result is
    checked with ``rail_trans.is_valid``. When `new_pos` is the end of the
    path, the cell at `end_pos` is checked the same way.

    Parameters
    ----------
    rail_trans : object
        Provides ``set_transition(cell, from_dir, to_dir, value)`` and
        ``is_valid(cell)``.
    rail_array : numpy.ndarray
        Grid of cell transition bitmaps, indexed by position tuples.
    prev_pos : tuple or None
        Position visited before `current_pos`; None at the start of a path.
    current_pos, new_pos : tuple
        Cell being left and adjacent cell being entered.
    end_pos : tuple
        Final position of the path being built.

    Returns
    -------
    bool
        True if every cell touched by the move stays a valid cell.
    """
    # direction used to leave the current cell, and the one used to reach it
    new_dir = get_direction(current_pos, new_pos)
    if prev_pos is None:
        current_dir = new_dir
    else:
        current_dir = get_direction(prev_pos, current_pos)

    new_trans = rail_array[current_pos]
    if prev_pos is None:
        # need to flip direction because of how end points are defined
        new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
    else:
        # set the forward path
        new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
        # set the backwards path
        new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)

    if new_pos == end_pos:
        # the end cell receives a transition of its own, validate it as well
        end_trans = rail_trans.set_transition(rail_array[end_pos], new_dir, mirror(new_dir), 1)
        if not rail_trans.is_valid(end_trans):
            return False

    return rail_trans.is_valid(new_trans)


def a_star(rail_trans, rail_array, start, end):
    """
    Find a path of adjacent grid cells from `start` to `end` with A*.

    A step is only taken when `validate_new_transition` accepts it. Nodes are
    expanded in order of lowest f = g + h, ties going to the node discovered
    first.

    Parameters
    ----------
    rail_trans : object
        Transition helper passed through to `validate_new_transition`.
    rail_array : numpy.ndarray
        2D grid of cell transition bitmaps.
    start, end : tuple
        (row, column) positions.

    Returns
    -------
    list of tuple
        Positions from `start` to `end`. If `end` cannot be reached, the path
        to the expanded position closest to `end` (Manhattan distance; the
        earliest expanded one wins ties), which may be just ``[start]``.
    """
    import heapq

    start, end = tuple(start), tuple(end)
    n_rows, n_cols = rail_array.shape[0], rail_array.shape[1]

    def heuristic(pos):
        # Manhattan distance: unlike the squared Euclidean distance it does
        # not favour diagonal staircases.
        return abs(pos[0] - end[0]) + abs(pos[1] - end[1])

    def path_to(pos):
        path = []
        while pos is not None:
            path.append(pos)
            pos = parent[pos]
        return path[::-1]

    parent = {start: None}   # position -> position it was reached from
    g_score = {start: 0}     # position -> best known cost from start
    order = {start: 0}       # position -> discovery rank, used to break f ties
    closed = set()           # positions already expanded
    open_heap = [(heuristic(start), 0, start)]
    closest, closest_h = start, heuristic(start)

    while open_heap:
        _, _, current = heapq.heappop(open_heap)
        if current in closed:
            # stale entry left behind when a cheaper route was pushed later
            continue
        closed.add(current)

        # found the goal
        if current == end:
            return path_to(current)

        current_h = heuristic(current)
        if current_h < closest_h:
            closest, closest_h = current, current_h

        prev_pos = parent[current]
        child_g = g_score[current] + 1
        for d_row, d_col in ((0, -1), (0, 1), (-1, 0), (1, 0)):
            child = (current[0] + d_row, current[1] + d_col)
            if not (0 <= child[0] < n_rows and 0 <= child[1] < n_cols):
                continue
            if child in closed:
                continue
            if not validate_new_transition(rail_trans, rail_array, prev_pos, current, child, end):
                continue
            if child in g_score and g_score[child] <= child_g:
                # already queued with a route that is at least as good
                continue
            g_score[child] = child_g
            parent[child] = current
            order.setdefault(child, len(order))
            heapq.heappush(open_heap, (child_g + heuristic(child), order[child], child))

    # no full path found, return the partial path that gets closest to the goal
    return path_to(closest)
def connect_rail(rail_trans, rail_array, start, end):
    """
    Creates a new path [start,end] in rail_array, based on rail_trans.

    The path comes from `a_star`; the transitions needed to travel along it
    are written into `rail_array` in place. Nothing is written when the path
    has fewer than two cells.
    """
    # in the worst case we will need to do a A* search, so we might as well set that up
    path = a_star(rail_trans, rail_array, start, end)
    if len(path) < 2:
        return

    end_pos = path[-1]
    current_dir = get_direction(path[0], path[1])
    for index, (current_pos, new_pos) in enumerate(zip(path, path[1:])):
        new_dir = get_direction(current_pos, new_pos)

        new_trans = rail_array[current_pos]
        if index == 0:
            # need to flip direction because of how end points are defined
            new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
        else:
            # set the forward path
            new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
            # set the backwards path
            new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
        rail_array[current_pos] = new_trans

        if new_pos == end_pos:
            # the last cell of the path gets its own end-point transition
            end_trans = rail_trans.set_transition(rail_array[end_pos], new_dir, mirror(new_dir), 1)
            rail_array[end_pos] = end_trans

        current_dir = new_dir


def distance_on_rail(pos1, pos2):
    """Return the Manhattan distance between two grid positions."""
    return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])


def _sample_start_goal_pairs(width, height, nr_start_goal, min_dist, max_dist, max_attempts=9000):
    """
    Draw random [start, goal] pairs from numpy's global random state.

    A pair is kept when the distance between its start and goal lies in
    [min_dist, max_dist] and both points are at least 2 cells away from every
    point of the pairs kept so far. A pair that cannot be placed within
    `max_attempts` draws is dropped, so fewer than `nr_start_goal` pairs may
    be returned.
    """
    pairs = []

    def clear_of_existing(points):
        return all(distance_on_rail(point, other) >= 2
                   for pair in pairs for other in pair for point in points)

    for _ in range(nr_start_goal):
        for _ in range(max_attempts):
            # Both points index the (width, height) array, so the first
            # coordinate must be drawn from [0, width) for the goal as well.
            start = (np.random.randint(0, width), np.random.randint(0, height))
            goal = (np.random.randint(0, width), np.random.randint(0, height))
            if not min_dist <= distance_on_rail(start, goal) <= max_dist:
                continue
            if clear_of_existing((start, goal)):
                pairs.append([start, goal])
                break
    return pairs


def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
    """
    Build a rail generator that connects randomly placed start/goal pairs.

    Parameters
    ----------
    nr_start_goal : int
        Number of [start, goal] pairs to place and connect.
    min_dist : int
        Minimum Manhattan distance between a start and its goal.
    max_dist : int
        Maximum Manhattan distance between a start and its goal.
    seed : int
        Base seed; each call of the generator seeds numpy's global random
        state with ``seed + num_resets``.

    Returns
    -------
    function
        ``generator(width, height, num_resets=0)``, which returns a
        GridTransitionMap whose ``grid`` is a numpy.uint16 array holding the
        16-bit transition bitmap of every cell.
    """

    def generator(width, height, num_resets=0):
        rail_trans = RailEnvTransitions()
        # NOTE(review): axis 0 is sized by `width` and axis 1 by `height`;
        # confirm this matches the layout GridTransitionMap expects.
        rail_array = np.zeros(shape=(width, height), dtype=np.uint16)

        np.random.seed(seed + num_resets)

        # Planned algorithm (only steps 1 and 3 are implemented so far):
        #   1. generate start/goal positions, respecting min/max distance and
        #      keeping them apart from other starts/goals
        #   2. (optional) place random elements such as train stations
        #   3. connect every [start, goal] pair, validating crossings with
        #      existing rails
        #   4. (optional) add more rails to the map randomly
        #   5. return the transition map + the list of [start, goal] points

        # step 1:
        start_goal = _sample_start_goal_pairs(width, height, nr_start_goal, min_dist, max_dist)
        print("Created #", len(start_goal), "pairs")

        # step 3:
        for start, goal in start_goal:
            connect_rail(rail_trans, rail_array, start, goal)

        return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans)
        return_rail.grid = rail_array
        # TODO: return start_goal
        return return_rail

    return generator
def test_is_valid_railenv_transitions():
    """Every cell template is valid in all four orientations; malformed bitmaps are not."""
    rail_env_trans = RailEnvTransitions()

    for template in rail_env_trans.transitions:
        assert rail_env_trans.is_valid(template)
        # 90, 180 and 270 degrees; the unrotated template is checked above
        for quarter_turns in range(1, 4):
            rotated = rail_env_trans.rotate_transition(template, 90 * quarter_turns)
            assert rail_env_trans.is_valid(rotated)

    for bits in ('1111111111110010', '1001111111110010', '1001111001110110'):
        assert not rail_env_trans.is_valid(int(bits, 2))


def test_adding_new_valid_transition():
    """Exercise validate_new_transition on an empty cell and on pre-filled cells."""
    rail_trans = RailEnvTransitions()
    rail_array = np.zeros(shape=(15, 15), dtype=np.uint16)

    # straight through an empty cell along the first axis
    assert validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (6, 5), (10, 10))

    # straight through an empty cell along the second axis
    # NOTE(review): originally labelled "right turn", but the three positions
    # are collinear - confirm the intended coordinates.
    assert validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (5, 6), (10, 10))

    # entering the cell and leaving it towards the cell it was entered from
    # NOTE(review): originally labelled "left turn", but prev_pos and new_pos
    # are the same cell - confirm the intended coordinates.
    assert validate_new_transition(rail_trans, rail_array, (5, 6), (5, 5), (5, 6), (10, 10))

    # adding a turn to template #2 is invalid
    rail_array[(5, 5)] = rail_trans.transitions[2]
    assert not validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10))

    # adding the same turn to template #3 should create #4 -> valid
    rail_array[(5, 5)] = rail_trans.transitions[3]
    assert validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10))

    # adding a turn to template #7 is invalid
    rail_array[(5, 5)] = rail_trans.transitions[7]
    assert not validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10))

    # path start condition: no previous position
    rail_array[(5, 5)] = rail_trans.transitions[0]
    assert validate_new_transition(rail_trans, rail_array, None, (5, 5), (5, 6), (10, 10))

    # path end condition: the new position is the end of the path
    rail_array[(5, 5)] = rail_trans.transitions[0]
    assert validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (6, 5), (6, 5))