diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py new file mode 100644 index 0000000000000000000000000000000000000000..feb72313f21b9ecc989688d63ba02ccf3a458107 --- /dev/null +++ b/flatland/core/grid/grid4_astar.py @@ -0,0 +1,106 @@ +from flatland.core.grid.grid4_utils import validate_new_transition + + +class AStarNode(): + """A node class for A* Pathfinding""" + + def __init__(self, parent=None, pos=None): + self.parent = parent + self.pos = pos + self.g = 0 + self.h = 0 + self.f = 0 + + def __eq__(self, other): + return self.pos == other.pos + + def __hash__(self): + return hash(self.pos) + + def update_if_better(self, other): + if other.g < self.g: + self.parent = other.parent + self.g = other.g + self.h = other.h + self.f = other.f + + +def a_star(rail_trans, rail_array, start, end): + """ + Returns a list of tuples as a path from the given start to end. + If no path is found, returns path to closest point to end. + """ + rail_shape = rail_array.shape + start_node = AStarNode(None, start) + end_node = AStarNode(None, end) + open_nodes = set() + closed_nodes = set() + open_nodes.add(start_node) + + while len(open_nodes) > 0: + # get node with current shortest est. path (lowest f) + current_node = None + for item in open_nodes: + if current_node is None: + current_node = item + continue + if item.f < current_node.f: + current_node = item + + # pop current off open list, add to closed list + open_nodes.remove(current_node) + closed_nodes.add(current_node) + + # found the goal + if current_node == end_node: + path = [] + current = current_node + while current is not None: + path.append(current.pos) + current = current.parent + # return reversed path + return path[::-1] + + # generate children + children = [] + if current_node.parent is not None: + prev_pos = current_node.parent.pos + else: + prev_pos = None + for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]: + node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1]) + if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0: + continue + + # validate positions + if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos): + continue + + # create new node + new_node = AStarNode(current_node, node_pos) + children.append(new_node) + + # loop through children + for child in children: + # already in closed list? + if child in closed_nodes: + continue + + # create the f, g, and h values + child.g = current_node.g + 1 + # this heuristic favors diagonal paths: + # child.h = ((child.pos[0] - end_node.pos[0]) ** 2) + ((child.pos[1] - end_node.pos[1]) ** 2) \# noqa: E800 + # this heuristic avoids diagonal paths + child.h = abs(child.pos[0] - end_node.pos[0]) + abs(child.pos[1] - end_node.pos[1]) + child.f = child.g + child.h + + # already in the open list? + if child in open_nodes: + continue + + # add the child to the open list + open_nodes.add(child) + + # no full path found + if len(open_nodes) == 0: + return [] diff --git a/flatland/core/grid/grid4_utils.py b/flatland/core/grid/grid4_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d64a160b6b422b5984ae675ab4544604a74b1337 --- /dev/null +++ b/flatland/core/grid/grid4_utils.py @@ -0,0 +1,74 @@ +from flatland.core.grid.grid4 import Grid4TransitionsEnum + + +def get_direction(pos1, pos2) -> Grid4TransitionsEnum: + """ + Assumes pos1 and pos2 are adjacent location on grid. + Returns direction (int) that can be used with transitions. + """ + diff_0 = pos2[0] - pos1[0] + diff_1 = pos2[1] - pos1[1] + if diff_0 < 0: + return 0 + if diff_0 > 0: + return 2 + if diff_1 > 0: + return 1 + if diff_1 < 0: + return 3 + raise Exception("Could not determine direction {}->{}".format(pos1, pos2)) + + +def mirror(dir): + return (dir + 2) % 4 + + +def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos): + # start by getting direction used to get to current node + # and direction from current node to possible child node + new_dir = get_direction(current_pos, new_pos) + if prev_pos is not None: + current_dir = get_direction(prev_pos, current_pos) + else: + current_dir = new_dir + # create new transition that would go to child + new_trans = rail_array[current_pos] + if prev_pos is None: + if new_trans == 0: + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + else: + # check if matches existing layout + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + if new_pos == end_pos: + # need to validate end pos setup as well + new_trans_e = rail_array[end_pos] + if new_trans_e == 0: + # need to flip direction because of how end points are defined + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + else: + # check if matches existing layout + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + + if not rail_trans.is_valid(new_trans_e): + return False + + # is transition is valid? + return rail_trans.is_valid(new_trans) + + +def get_new_position(position, movement): + """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """ + if movement == Grid4TransitionsEnum.NORTH: + return (position[0] - 1, position[1]) + elif movement == Grid4TransitionsEnum.EAST: + return (position[0], position[1] + 1) + elif movement == Grid4TransitionsEnum.SOUTH: + return (position[0] + 1, position[1]) + elif movement == Grid4TransitionsEnum.WEST: + return (position[0], position[1] - 1) diff --git a/flatland/core/grid/grid_utils.py b/flatland/core/grid/grid_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f51d285947ae175692ca70e1f03d08b8a22a5a22 --- /dev/null +++ b/flatland/core/grid/grid_utils.py @@ -0,0 +1,58 @@ +import numpy as np + + +def position_to_coordinate(depth, positions): + """Converts coordinates to positions: + [ (0,0) (0,1) .. (0,w-1) + (1,0) (1,1) (1,w-1) + ... + (d-1,0) (d-1,1) (d-1,w-1) + ] + + --> + + [ 0 d .. (w-1)*d + 1 d+1 + ... + d-1 2d-1 w*d-1 + ] + + :param depth: + :param positions: + :return: + """ + coords = () + for p in positions: + coords = coords + ((int(p) % depth, int(p) // depth),) # changed x_dim to y_dim + return coords + + +def coordinate_to_position(depth, coords): + """ + Converts positions to coordinates: + [ 0 d .. (w-1)*d + 1 d+1 + ... + d-1 2d-1 w*d-1 + ] + --> + [ (0,0) (0,1) .. (0,w-1) + (1,0) (1,1) (1,w-1) + ... + (d-1,0) (d-1,1) (d-1,w-1) + ] + + :param depth: + :param coords: + :return: + """ + position = np.empty(len(coords), dtype=int) + idx = 0 + for t in coords: + position[idx] = int(t[1] * depth + t[0]) + idx += 1 + return position + + +def distance_on_rail(pos1, pos2): + return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1]) diff --git a/flatland/core/grid/rail_env_grid.py b/flatland/core/grid/rail_env_grid.py index c043b42f1ca84ba9d0f7a68f5e18a192ff374d7a..680e945316ab3a4876bd36fa8e6b001ea346cd26 100644 --- a/flatland/core/grid/rail_env_grid.py +++ b/flatland/core/grid/rail_env_grid.py @@ -62,38 +62,6 @@ class RailEnvTransitions(Grid4Transitions): print("S", format(cell_transition >> (1 * 4) & 0xF, '04b')) print("W", format(cell_transition >> (0 * 4) & 0xF, '04b')) - def repr(self, cell_transition, version=0): - """ - Provide a string representation of the cell transitions. - This class doesn't represent an individual cell, - but a way of interpreting the contents of a cell. - So using the ad hoc name repr rather than __repr__. - """ - # binary format string without leading 0b - sbinTrans = format(cell_transition, "#018b")[2:] - if version == 0: - sRepr = " ".join([ - "{}:{}".format(sDir, sbinTrans[i:(i + 4)]) - for i, sDir in - zip( - range(0, len(sbinTrans), 4), - self.lsDirs)]) # NESW - return sRepr - - if version == 1: - lsRepr = [] - for iDirIn in range(0, 4): - sDirTrans = sbinTrans[(iDirIn * 4):(iDirIn * 4 + 4)] - if sDirTrans == "0000": - continue - sDirsOut = [ - self.lsDirs[iDirOut] - for iDirOut in range(0, 4) - if sDirTrans[iDirOut] == "1"] - lsRepr.append(self.lsDirs[iDirIn] + ":" + "".join(sDirsOut)) - - return ", ".join(lsRepr) - def is_valid(self, cell_transition): """ Checks if a cell transition is a valid cell setup. diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index aa46aecd4b69b6a13b11b63223123b16dd69e3ac..5eadb9332a4a084ad030a46dc18bab81f33ac4e0 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -4,20 +4,6 @@ import numpy as np from attr import attrs, attrib -@attrs -class EnvDescription(object): - """ EnvDescription - This is a description of a random env, - based around the rail_generator and stats like size and n_agents. - It mirrors the parameters given to the RailEnv constructor. - Not currently used. - """ - n_agents = attrib() - height = attrib() - width = attrib() - rail_generator = attrib() - obs_builder = attrib() # not sure if this should closer to the agent than the env - - @attrs class EnvAgentStatic(object): """ EnvAgentStatic - Stores initial position, direction and target. @@ -34,18 +20,6 @@ class EnvAgentStatic(object): # cell if speed=1, as default) speed_data = attrib(default=dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})) - def __init__(self, - position, - direction, - target, - moving=False, - speed_data={'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0}): - self.position = position - self.direction = direction - self.target = target - self.moving = moving - self.speed_data = speed_data - @classmethod def from_lists(cls, positions, directions, targets, speeds=None): """ Create a list of EnvAgentStatics from lists of positions, directions and targets @@ -84,12 +58,6 @@ class EnvAgent(EnvAgentStatic): old_direction = attrib(default=None) old_position = attrib(default=None) - def __init__(self, position, direction, target, handle, old_direction, old_position): - super(EnvAgent, self).__init__(position, direction, target) - self.handle = handle - self.old_direction = old_direction - self.old_position = old_position - def to_list(self): return [ self.position, self.direction, self.target, self.handle, diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py deleted file mode 100644 index 19da8946864245343a09857d8b2fbe968f47ba35..0000000000000000000000000000000000000000 --- a/flatland/envs/env_utils.py +++ /dev/null @@ -1,369 +0,0 @@ -""" -Definition of the RailEnv environment and related level-generation functions. - -Generator functions are functions that take width, height and num_resets as arguments and return -a GridTransitionMap object. -""" - -import numpy as np - -from flatland.core.grid.grid4 import Grid4TransitionsEnum - - -def get_direction(pos1, pos2) -> Grid4TransitionsEnum: - """ - Assumes pos1 and pos2 are adjacent location on grid. - Returns direction (int) that can be used with transitions. - """ - diff_0 = pos2[0] - pos1[0] - diff_1 = pos2[1] - pos1[1] - if diff_0 < 0: - return 0 - if diff_0 > 0: - return 2 - if diff_1 > 0: - return 1 - if diff_1 < 0: - return 3 - raise Exception("Could not determine direction {}->{}".format(pos1, pos2)) - - -def mirror(dir): - return (dir + 2) % 4 - - -def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos): - # start by getting direction used to get to current node - # and direction from current node to possible child node - new_dir = get_direction(current_pos, new_pos) - if prev_pos is not None: - current_dir = get_direction(prev_pos, current_pos) - else: - current_dir = new_dir - # create new transition that would go to child - new_trans = rail_array[current_pos] - if prev_pos is None: - if new_trans == 0: - # need to flip direction because of how end points are defined - new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) - else: - # check if matches existing layout - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - else: - # set the forward path - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - # set the backwards path - new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - if new_pos == end_pos: - # need to validate end pos setup as well - new_trans_e = rail_array[end_pos] - if new_trans_e == 0: - # need to flip direction because of how end points are defined - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) - else: - # check if matches existing layout - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) - - if not rail_trans.is_valid(new_trans_e): - return False - - # is transition is valid? - return rail_trans.is_valid(new_trans) - - -def position_to_coordinate(depth, positions): - """Converts coordinates to positions: - [ (0,0) (0,1) .. (0,w-1) - (1,0) (1,1) (1,w-1) - ... - (d-1,0) (d-1,1) (d-1,w-1) - ] - - --> - - [ 0 d .. (w-1)*d - 1 d+1 - ... - d-1 2d-1 w*d-1 - ] - - :param depth: - :param positions: - :return: - """ - coords = () - for p in positions: - coords = coords + ((int(p) % depth, int(p) // depth),) # changed x_dim to y_dim - return coords - - -def coordinate_to_position(depth, coords): - """ - Converts positions to coordinates: - [ 0 d .. (w-1)*d - 1 d+1 - ... - d-1 2d-1 w*d-1 - ] - --> - [ (0,0) (0,1) .. (0,w-1) - (1,0) (1,1) (1,w-1) - ... - (d-1,0) (d-1,1) (d-1,w-1) - ] - - :param depth: - :param coords: - :return: - """ - position = np.empty(len(coords), dtype=int) - idx = 0 - for t in coords: - position[idx] = int(t[1] * depth + t[0]) - idx += 1 - return position - - -def get_new_position(position, movement): - """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """ - if movement == Grid4TransitionsEnum.NORTH: - return (position[0] - 1, position[1]) - elif movement == Grid4TransitionsEnum.EAST: - return (position[0], position[1] + 1) - elif movement == Grid4TransitionsEnum.SOUTH: - return (position[0] + 1, position[1]) - elif movement == Grid4TransitionsEnum.WEST: - return (position[0], position[1] - 1) - - -class AStarNode(): - """A node class for A* Pathfinding""" - - def __init__(self, parent=None, pos=None): - self.parent = parent - self.pos = pos - self.g = 0 - self.h = 0 - self.f = 0 - - def __eq__(self, other): - return self.pos == other.pos - - def __hash__(self): - return hash(self.pos) - - def update_if_better(self, other): - if other.g < self.g: - self.parent = other.parent - self.g = other.g - self.h = other.h - self.f = other.f - - -def a_star(rail_trans, rail_array, start, end): - """ - Returns a list of tuples as a path from the given start to end. - If no path is found, returns path to closest point to end. - """ - rail_shape = rail_array.shape - start_node = AStarNode(None, start) - end_node = AStarNode(None, end) - open_nodes = set() - closed_nodes = set() - open_nodes.add(start_node) - - while len(open_nodes) > 0: - # get node with current shortest est. path (lowest f) - current_node = None - for item in open_nodes: - if current_node is None: - current_node = item - continue - if item.f < current_node.f: - current_node = item - - # pop current off open list, add to closed list - open_nodes.remove(current_node) - closed_nodes.add(current_node) - - # found the goal - if current_node == end_node: - path = [] - current = current_node - while current is not None: - path.append(current.pos) - current = current.parent - # return reversed path - return path[::-1] - - # generate children - children = [] - if current_node.parent is not None: - prev_pos = current_node.parent.pos - else: - prev_pos = None - for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]: - node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1]) - if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0: - continue - - # validate positions - if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos): - continue - - # create new node - new_node = AStarNode(current_node, node_pos) - children.append(new_node) - - # loop through children - for child in children: - # already in closed list? - if child in closed_nodes: - continue - - # create the f, g, and h values - child.g = current_node.g + 1 - # this heuristic favors diagonal paths: - # child.h = ((child.pos[0] - end_node.pos[0]) ** 2) + ((child.pos[1] - end_node.pos[1]) ** 2) \# noqa: E800 - # this heuristic avoids diagonal paths - child.h = abs(child.pos[0] - end_node.pos[0]) + abs(child.pos[1] - end_node.pos[1]) - child.f = child.g + child.h - - # already in the open list? - if child in open_nodes: - continue - - # add the child to the open list - open_nodes.add(child) - - # no full path found - if len(open_nodes) == 0: - return [] - - -def connect_rail(rail_trans, rail_array, start, end): - """ - Creates a new path [start,end] in rail_array, based on rail_trans. - """ - # in the worst case we will need to do a A* search, so we might as well set that up - path = a_star(rail_trans, rail_array, start, end) - if len(path) < 2: - return [] - current_dir = get_direction(path[0], path[1]) - end_pos = path[-1] - for index in range(len(path) - 1): - current_pos = path[index] - new_pos = path[index + 1] - new_dir = get_direction(current_pos, new_pos) - - new_trans = rail_array[current_pos] - if index == 0: - if new_trans == 0: - # end-point - # need to flip direction because of how end points are defined - new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) - else: - # into existing rail - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - else: - # set the forward path - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - # set the backwards path - new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - rail_array[current_pos] = new_trans - - if new_pos == end_pos: - # setup end pos setup - new_trans_e = rail_array[end_pos] - if new_trans_e == 0: - # end-point - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) - else: - # into existing rail - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) - rail_array[end_pos] = new_trans_e - - current_dir = new_dir - return path - - -def distance_on_rail(pos1, pos2): - return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1]) - - -def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): - """ - Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). - - TODO: add extensive documentation, as users may need this function to simplify their custom level generators. - """ - - def _path_exists(rail, start, direction, end): - # BFS - Check if a path exists between the 2 nodes - - visited = set() - stack = [(start, direction)] - while stack: - node = stack.pop() - if node[0][0] == end[0] and node[0][1] == end[1]: - return 1 - if node not in visited: - visited.add(node) - moves = rail.get_transitions((node[0][0], node[0][1], node[1])) - for move_index in range(4): - if moves[move_index]: - stack.append((get_new_position(node[0], move_index), - move_index)) - - # If cell is a dead-end, append previous node with reversed - # orientation! - nbits = 0 - tmp = rail.get_transitions((node[0][0], node[0][1])) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - stack.append((node[0], (node[1] + 2) % 4)) - - return 0 - - valid_positions = [] - for r in range(rail.height): - for c in range(rail.width): - if rail.get_transitions((r, c)) > 0: - valid_positions.append((r, c)) - - re_generate = True - while re_generate: - agents_position = [ - valid_positions[i] for i in - np.random.choice(len(valid_positions), num_agents)] - agents_target = [ - valid_positions[i] for i in - np.random.choice(len(valid_positions), num_agents)] - - # agents_direction must be a direction for which a solution is - # guaranteed. - agents_direction = [0] * num_agents - re_generate = False - for i in range(num_agents): - valid_movements = [] - for direction in range(4): - position = agents_position[i] - moves = rail.get_transitions((position[0], position[1], direction)) - for move_index in range(4): - if moves[move_index]: - valid_movements.append((direction, move_index)) - - valid_starting_directions = [] - for m in valid_movements: - new_position = get_new_position(agents_position[i], m[1]) - if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]): - valid_starting_directions.append(m[0]) - - if len(valid_starting_directions) == 0: - re_generate = True - else: - agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]] - - return agents_position, agents_direction, agents_target diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index fa5cdccc91fd837ac7b034564dffb5f909b81a1e..c3f569ae0d927a6e71803a24d921337c65d39c29 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -2,8 +2,10 @@ import numpy as np from flatland.core.transition_map import GridTransitionMap from flatland.core.grid.rail_env_grid import RailEnvTransitions -from flatland.envs.env_utils import distance_on_rail, connect_rail, get_direction, mirror -from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail +from flatland.envs.grid4_generators_utils import connect_rail +from flatland.core.grid.grid_utils import distance_on_rail +from flatland.core.grid.grid4_utils import get_direction, mirror +from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail def empty_rail_generator(): diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4b2ab8cbc8b956698383cd3961d77df6fb1ef195 --- /dev/null +++ b/flatland/envs/grid4_generators_utils.py @@ -0,0 +1,135 @@ +""" +Definition of the RailEnv environment and related level-generation functions. + +Generator functions are functions that take width, height and num_resets as arguments and return +a GridTransitionMap object. +""" + +import numpy as np + +from flatland.core.grid.grid4_astar import a_star +from flatland.core.grid.grid4_utils import get_direction, mirror, get_new_position + + +def connect_rail(rail_trans, rail_array, start, end): + """ + Creates a new path [start,end] in rail_array, based on rail_trans. + """ + # in the worst case we will need to do a A* search, so we might as well set that up + path = a_star(rail_trans, rail_array, start, end) + if len(path) < 2: + return [] + current_dir = get_direction(path[0], path[1]) + end_pos = path[-1] + for index in range(len(path) - 1): + current_pos = path[index] + new_pos = path[index + 1] + new_dir = get_direction(current_pos, new_pos) + + new_trans = rail_array[current_pos] + if index == 0: + if new_trans == 0: + # end-point + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + else: + # into existing rail + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + rail_array[current_pos] = new_trans + + if new_pos == end_pos: + # setup end pos setup + new_trans_e = rail_array[end_pos] + if new_trans_e == 0: + # end-point + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + else: + # into existing rail + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + rail_array[end_pos] = new_trans_e + + current_dir = new_dir + return path + + +def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): + """ + Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). + + TODO: add extensive documentation, as users may need this function to simplify their custom level generators. + """ + + def _path_exists(rail, start, direction, end): + # BFS - Check if a path exists between the 2 nodes + + visited = set() + stack = [(start, direction)] + while stack: + node = stack.pop() + if node[0][0] == end[0] and node[0][1] == end[1]: + return 1 + if node not in visited: + visited.add(node) + moves = rail.get_transitions((node[0][0], node[0][1], node[1])) + for move_index in range(4): + if moves[move_index]: + stack.append((get_new_position(node[0], move_index), + move_index)) + + # If cell is a dead-end, append previous node with reversed + # orientation! + nbits = 0 + tmp = rail.get_transitions((node[0][0], node[0][1])) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + stack.append((node[0], (node[1] + 2) % 4)) + + return 0 + + valid_positions = [] + for r in range(rail.height): + for c in range(rail.width): + if rail.get_transitions((r, c)) > 0: + valid_positions.append((r, c)) + + re_generate = True + while re_generate: + agents_position = [ + valid_positions[i] for i in + np.random.choice(len(valid_positions), num_agents)] + agents_target = [ + valid_positions[i] for i in + np.random.choice(len(valid_positions), num_agents)] + + # agents_direction must be a direction for which a solution is + # guaranteed. + agents_direction = [0] * num_agents + re_generate = False + for i in range(num_agents): + valid_movements = [] + for direction in range(4): + position = agents_position[i] + moves = rail.get_transitions((position[0], position[1], direction)) + for move_index in range(4): + if moves[move_index]: + valid_movements.append((direction, move_index)) + + valid_starting_directions = [] + for m in valid_movements: + new_position = get_new_position(agents_position[i], m[1]) + if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]): + valid_starting_directions.append(m[0]) + + if len(valid_starting_directions) == 0: + re_generate = True + else: + agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]] + + return agents_position, agents_direction, agents_target diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index ecb0697899ab64ce5bb455c33db925b8561bd14a..8c3f260ba83e91e7657b7a0b3f8a70c5d228f9b1 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -7,7 +7,7 @@ import numpy as np from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.envs.env_utils import coordinate_to_position +from flatland.core.grid.grid_utils import coordinate_to_position class TreeObsForRailEnv(ObservationBuilder): diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 43909669a67c527ae6fb935e22810eb47c9608cd..654f5490da875afe746d516892a773308b466589 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -5,7 +5,7 @@ Collection of environment-specific PredictionBuilder. import numpy as np from flatland.core.env_prediction_builder import PredictionBuilder -from flatland.envs.env_utils import get_new_position +from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.rail_env import RailEnvActions diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 8cf6d52f383ec8f4e271eb0765d32bc0c763307a..05738ecc407f275ecb0ef9c1d9b72e50a19de695 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -13,7 +13,7 @@ import numpy as np from flatland.core.env import Environment from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent -from flatland.envs.env_utils import get_new_position +from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.generators import random_rail_generator from flatland.envs.observations import TreeObsForRailEnv diff --git a/flatland/flatland.py b/flatland/flatland.py deleted file mode 100644 index 7fbbae4f9c58882c3754a89675312f3c1430ffd8..0000000000000000000000000000000000000000 --- a/flatland/flatland.py +++ /dev/null @@ -1,3 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Main module.""" diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index 9979122b2b295b19ddcbfbf35e61f4848853321e..f58133560c4c51e6feec5b50861795befc2e9625 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -10,7 +10,7 @@ from numpy import array import flatland.utils.rendertools as rt from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic -from flatland.envs.env_utils import mirror +from flatland.core.grid.grid4_utils import mirror from flatland.envs.generators import complex_rail_generator, empty_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv, random_rail_generator diff --git a/flatland/utils/svg.py b/flatland/utils/svg.py index b2e0284407d52cbbc53ae74c0a744a457d7fdbeb..e8399d51b4f9b5a8d1d1d0e446c695f799c84381 100644 --- a/flatland/utils/svg.py +++ b/flatland/utils/svg.py @@ -14,9 +14,6 @@ class SVG(object): elif svgETree is not None: self.svg = svgETree - self.init2() - - def init2(self): expr = "//*[local-name() = $name]" self.eStyle = self.svg.root.xpath(expr, name="style")[0] ltMatch = re.findall(r".st([a-zA-Z0-9]+)[{]([^}]*)}", self.eStyle.text) @@ -25,8 +22,7 @@ class SVG(object): def copy(self): new_svg = copy.deepcopy(self.svg) - self2 = SVG(svgETree=new_svg) - return self2 + return SVG(svgETree=new_svg) def merge(self, svg2): svg3 = svg2.copy() diff --git a/tests/test_flatland_core_transitions.py b/tests/test_flatland_core_transitions.py index 048520c17eeddf4d1f4a4c6beeb49427887b77f4..9d4a72c056f9825383f1c9d055509e870219fd9f 100644 --- a/tests/test_flatland_core_transitions.py +++ b/tests/test_flatland_core_transitions.py @@ -7,7 +7,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid8 import Grid8Transitions from flatland.core.grid.rail_env_grid import RailEnvTransitions -from flatland.envs.env_utils import validate_new_transition +from flatland.core.grid.grid4_utils import validate_new_transition # remove whitespace in string; keep whitespace below for easier reading diff --git a/tests/test_flatland_envs_env_utils.py b/tests/test_flatland_envs_env_utils.py index 49b619a159870c9105137ed41a8b55aa1dd19e36..b95922cf67febdaa0aad396459bc446bc31adfea 100644 --- a/tests/test_flatland_envs_env_utils.py +++ b/tests/test_flatland_envs_env_utils.py @@ -2,7 +2,8 @@ import numpy as np import pytest from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.envs.env_utils import position_to_coordinate, coordinate_to_position, get_direction +from flatland.core.grid.grid_utils import position_to_coordinate, coordinate_to_position +from flatland.core.grid.grid4_utils import get_direction depth_to_test = 5 positions_to_test = [0, 5, 1, 6, 20, 30]