From 0a151b327ec4894366491dfce1f16e5725a25465 Mon Sep 17 00:00:00 2001 From: hagrid67 <jdhwatson@gmail.com> Date: Wed, 1 May 2019 20:16:19 +0100 Subject: [PATCH] split up rail_env.py into rail_env, generators and env_utils.py --- examples/play_model.py | 3 +- flatland/envs/rail_env.py | 733 +------------------------- flatland/utils/editor.py | 24 +- tests/test_env_observation_builder.py | 3 +- tests/test_environments.py | 4 +- tests/test_transitions.py | 3 +- 6 files changed, 27 insertions(+), 743 deletions(-) diff --git a/examples/play_model.py b/examples/play_model.py index 1f654c1..e6e81c9 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -1,4 +1,5 @@ -from flatland.envs.rail_env import RailEnv, complex_rail_generator +from flatland.envs.rail_env import RailEnv +from flatland.envs.generators import complex_rail_generator # from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.utils.rendertools import RenderTool from flatland.baselines.dueling_double_dqn import Agent diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 8969af3..eb8786c 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -8,737 +8,10 @@ import numpy as np from flatland.core.env import Environment from flatland.core.env_observation_builder import TreeObsForRailEnv +from flatland.envs.generators import random_rail_generator -from flatland.core.transitions import Grid8Transitions, RailEnvTransitions -from flatland.core.transition_map import GridTransitionMap - - -class AStarNode(): - """A node class for A* Pathfinding""" - - def __init__(self, parent=None, pos=None): - self.parent = parent - self.pos = pos - self.g = 0 - self.h = 0 - self.f = 0 - - def __eq__(self, other): - return self.pos == other.pos - - def update_if_better(self, other): - if other.g < self.g: - self.parent = other.parent - self.g = other.g - self.h = other.h - self.f = other.f - - -def get_direction(pos1, pos2): - """ - Assumes pos1 and pos2 are adjacent location on grid. - Returns direction (int) that can be used with transitions. - """ - diff_0 = pos2[0] - pos1[0] - diff_1 = pos2[1] - pos1[1] - if diff_0 < 0: - return 0 - if diff_0 > 0: - return 2 - if diff_1 > 0: - return 1 - if diff_1 < 0: - return 3 - return 0 - - -def mirror(dir): - return (dir + 2) % 4 - - -def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos): - # start by getting direction used to get to current node - # and direction from current node to possible child node - new_dir = get_direction(current_pos, new_pos) - if prev_pos is not None: - current_dir = get_direction(prev_pos, current_pos) - else: - current_dir = new_dir - # create new transition that would go to child - new_trans = rail_array[current_pos] - if prev_pos is None: - if new_trans == 0: - # need to flip direction because of how end points are defined - new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) - else: - # check if matches existing layout - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - # new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - # rail_trans.print(new_trans) - else: - # set the forward path - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - # set the backwards path - new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - if new_pos == end_pos: - # need to validate end pos setup as well - new_trans_e = rail_array[end_pos] - if new_trans_e == 0: - # need to flip direction because of how end points are defined - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) - else: - # check if matches existing layout - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) - # new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1) - # print("end:", end_pos, current_pos) - # rail_trans.print(new_trans_e) - - # print("========> end trans") - # rail_trans.print(new_trans_e) - if not rail_trans.is_valid(new_trans_e): - # print("end failed", end_pos, current_pos) - return False - # else: - # print("end ok!", end_pos, current_pos) - - # is transition is valid? - # print("=======> trans") - # rail_trans.print(new_trans) - return rail_trans.is_valid(new_trans) - - -def a_star(rail_trans, rail_array, start, end): - """ - Returns a list of tuples as a path from the given start to end. - If no path is found, returns path to closest point to end. - """ - rail_shape = rail_array.shape - start_node = AStarNode(None, start) - end_node = AStarNode(None, end) - open_list = [] - closed_list = [] - - open_list.append(start_node) - - # this could be optimized - def is_node_in_list(node, the_list): - for o_node in the_list: - if node == o_node: - return o_node - return None - - while len(open_list) > 0: - # get node with current shortest est. path (lowest f) - current_node = open_list[0] - current_index = 0 - for index, item in enumerate(open_list): - if item.f < current_node.f: - current_node = item - current_index = index - - # pop current off open list, add to closed list - open_list.pop(current_index) - closed_list.append(current_node) - - # print("a*:", current_node.pos) - # for cn in closed_list: - # print("closed:", cn.pos) - - # found the goal - if current_node == end_node: - path = [] - current = current_node - while current is not None: - path.append(current.pos) - current = current.parent - # return reversed path - return path[::-1] - - # generate children - children = [] - if current_node.parent is not None: - prev_pos = current_node.parent.pos - else: - prev_pos = None - for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]: - node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1]) - if node_pos[0] >= rail_shape[0] or \ - node_pos[0] < 0 or \ - node_pos[1] >= rail_shape[1] or \ - node_pos[1] < 0: - continue - - # validate positions - # debug: avoid all current rails - # if rail_array.item(node_pos) != 0: - # continue - - # validate positions - if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos): - # print("A*: transition invalid") - continue - - # create new node - new_node = AStarNode(current_node, node_pos) - children.append(new_node) - - # loop through children - for child in children: - # already in closed list? - closed_node = is_node_in_list(child, closed_list) - if closed_node is not None: - continue - - # create the f, g, and h values - child.g = current_node.g + 1 - # this heuristic favors diagonal paths - # child.h = ((child.pos[0] - end_node.pos[0]) ** 2) + \ - # ((child.pos[1] - end_node.pos[1]) ** 2) - # this heuristic avoids diagonal paths - child.h = abs(child.pos[0] - end_node.pos[0]) + abs(child.pos[1] - end_node.pos[1]) - child.f = child.g + child.h - - # already in the open list? - open_node = is_node_in_list(child, open_list) - if open_node is not None: - open_node.update_if_better(child) - continue - - # add the child to the open list - open_list.append(child) - - # no full path found, return partial path - if len(open_list) == 0: - path = [] - current = current_node - while current is not None: - path.append(current.pos) - current = current.parent - # return reversed path - print("partial:", start, end, path[::-1]) - return path[::-1] - - -def connect_rail(rail_trans, rail_array, start, end): - """ - Creates a new path [start,end] in rail_array, based on rail_trans. - """ - # in the worst case we will need to do a A* search, so we might as well set that up - path = a_star(rail_trans, rail_array, start, end) - # print("connecting path", path) - if len(path) < 2: - return - current_dir = get_direction(path[0], path[1]) - end_pos = path[-1] - for index in range(len(path) - 1): - current_pos = path[index] - new_pos = path[index + 1] - new_dir = get_direction(current_pos, new_pos) - - new_trans = rail_array[current_pos] - if index == 0: - if new_trans == 0: - # end-point - # need to flip direction because of how end points are defined - new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) - else: - # into existing rail - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - # new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - pass - else: - # set the forward path - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - # set the backwards path - new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - rail_array[current_pos] = new_trans - - if new_pos == end_pos: - # setup end pos setup - new_trans_e = rail_array[end_pos] - if new_trans_e == 0: - # end-point - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) - else: - # into existing rail - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) - # new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1) - rail_array[end_pos] = new_trans_e - - current_dir = new_dir - - -def distance_on_rail(pos1, pos2): - return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1]) - - -def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): - """ - Parameters - ------- - width : int - The width (number of cells) of the grid to generate. - height : int - The height (number of cells) of the grid to generate. - - Returns - ------- - numpy.ndarray of type numpy.uint16 - The matrix with the correct 16-bit bitmaps for each cell. - """ - - def generator(width, height, num_resets=0): - rail_trans = RailEnvTransitions() - rail_array = np.zeros(shape=(width, height), dtype=np.uint16) - - np.random.seed(seed + num_resets) - - # generate rail array - # step 1: - # - generate a list of start and goal positions - # - use a min/max distance allowed as input for this - # - validate that start/goals are not placed too close to other start/goals - # - # step 2: (optional) - # - place random elements on rails array - # - for instance "train station", etc. - # - # step 3: - # - iterate over all [start, goal] pairs: - # - [first X pairs] - # - draw a rail from [start,goal] - # - draw either vertical or horizontal part first (randomly) - # - if rail crosses existing rail then validate new connection - # - if new connection is invalid turn 90 degrees to left/right - # - possibility that this fails to create a path to goal - # - on failure goto step1 and retry with seed+1 - # - [avoid crossing other start,goal positions] (optional) - # - # - [after X pairs] - # - find closest rail from start (Pa) - # - iterating outwards in a "circle" from start until an existing rail cell is hit - # - connect [start, Pa] - # - validate crossing rails - # - Do A* from Pa to find closest point on rail (Pb) to goal point - # - Basically normal A* but find point on rail which is closest to goal - # - since full path to goal is unlikely - # - connect [Pb, goal] - # - validate crossing rails - # - # step 4: (optional) - # - add more rails to map randomly - # - # step 5: - # - return transition map + list of [start, goal] points - # - - start_goal = [] - for _ in range(nr_start_goal): - sanity_max = 9000 - for _ in range(sanity_max): - start = (np.random.randint(0, width), np.random.randint(0, height)) - goal = (np.random.randint(0, height), np.random.randint(0, height)) - # check to make sure start,goal pos is empty? - if rail_array[goal] != 0 or rail_array[start] != 0: - continue - # check min/max distance - dist_sg = distance_on_rail(start, goal) - if dist_sg < min_dist: - continue - if dist_sg > max_dist: - continue - # check distance to existing points - sg_new = [start, goal] - - def check_all_dist(sg_new): - for sg in start_goal: - for i in range(2): - for j in range(2): - dist = distance_on_rail(sg_new[i], sg[j]) - if dist < 2: - # print("too close:", dist, sg_new[i], sg[j]) - return False - return True - - if check_all_dist(sg_new): - break - start_goal.append([start, goal]) - connect_rail(rail_trans, rail_array, start, goal) - - print("Created #", len(start_goal), "pairs") - # print(start_goal) - - return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans) - return_rail.grid = rail_array - # TODO: return start_goal - return return_rail - - return generator - - -def rail_from_manual_specifications_generator(rail_spec): - """ - Utility to convert a rail given by manual specification as a map of tuples - (cell_type, rotation), to a transition map with the correct 16-bit - transitions specifications. - - Parameters - ------- - rail_spec : list of list of tuples - List (rows) of lists (columns) of tuples, each specifying a cell for - the RailEnv environment as (cell_type, rotation), with rotation being - clock-wise and in [0, 90, 180, 270]. - - Returns - ------- - function - Generator function that always returns a GridTransitionMap object with - the matrix of correct 16-bit bitmaps for each cell. - """ - - def generator(width, height, num_resets=0): - t_utils = RailEnvTransitions() - - height = len(rail_spec) - width = len(rail_spec[0]) - rail = GridTransitionMap(width=width, height=height, transitions=t_utils) - - for r in range(height): - for c in range(width): - cell = rail_spec[r][c] - if cell[0] < 0 or cell[0] >= len(t_utils.transitions): - print("ERROR - invalid cell type=", cell[0]) - return [] - rail.set_transitions((r, c), t_utils.rotate_transition(t_utils.transitions[cell[0]], cell[1])) - - return rail - - return generator - - -def rail_from_GridTransitionMap_generator(rail_map): - """ - Utility to convert a rail given by a GridTransitionMap map with the correct - 16-bit transitions specifications. - - Parameters - ------- - rail_map : GridTransitionMap object - GridTransitionMap object to return when the generator is called. - - Returns - ------- - function - Generator function that always returns the given `rail_map' object. - """ - - def generator(width, height, num_resets=0): - return rail_map - - return generator - - -def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames): - """ - Utility to sequentially and cyclically return GridTransitionMap-s from a list of files, on each environment reset. - - Parameters - ------- - list_of_filenames : list - List of filenames with the saved grids to load. - - Returns - ------- - function - Generator function that always returns the given `rail_map' object. - """ - - def generator(width, height, num_resets=0): - t_utils = RailEnvTransitions() - rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils) - rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False) - - if rail_map.grid.dtype == np.uint64: - rail_map.transitions = Grid8Transitions() - - return rail_map - - return generator - - -""" -def generate_rail_from_list_of_manual_specifications(list_of_specifications) - def generator(width, height, num_resets=0): - return generate_rail_from_manual_specifications(list_of_specifications) - - return generator -""" - - -def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): - """ - Dummy random level generator: - - fill in cells at random in [width-2, height-2] - - keep filling cells in among the unfilled ones, such that all transitions - are legit; if no cell can be filled in without violating some - transitions, pick one among those that can satisfy most transitions - (1,2,3 or 4), and delete (+mark to be re-filled) the cells that were - incompatible. - - keep trying for a total number of insertions - (e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the - board and try again from scratch. - - finally pad the border of the map with dead-ends to avoid border issues. - - Dead-ends are not allowed inside the grid, only at the border; however, if - no cell type can be inserted in a given cell (because of the neighboring - transitions), deadends are allowed if they solve the problem. This was - found to turn most un-genereatable levels into valid ones. - - Parameters - ------- - width : int - The width (number of cells) of the grid to generate. - height : int - The height (number of cells) of the grid to generate. - - Returns - ------- - numpy.ndarray of type numpy.uint16 - The matrix with the correct 16-bit bitmaps for each cell. - """ - - def generator(width, height, num_resets=0): - t_utils = RailEnvTransitions() - - transition_probability = cell_type_relative_proportion - - transitions_templates_ = [] - transition_probabilities = [] - for i in range(len(t_utils.transitions) - 4): # don't include dead-ends - all_transitions = 0 - for dir_ in range(4): - trans = t_utils.get_transitions(t_utils.transitions[i], dir_) - all_transitions |= (trans[0] << 3) | \ - (trans[1] << 2) | \ - (trans[2] << 1) | \ - (trans[3]) - - template = [int(x) for x in bin(all_transitions)[2:]] - template = [0] * (4 - len(template)) + template - - # add all rotations - for rot in [0, 90, 180, 270]: - transitions_templates_.append((template, - t_utils.rotate_transition( - t_utils.transitions[i], - rot))) - transition_probabilities.append(transition_probability[i]) - template = [template[-1]] + template[:-1] - - def get_matching_templates(template): - ret = [] - for i in range(len(transitions_templates_)): - is_match = True - for j in range(4): - if template[j] >= 0 and template[j] != transitions_templates_[i][0][j]: - is_match = False - break - if is_match: - ret.append((transitions_templates_[i][1], transition_probabilities[i])) - return ret - - MAX_INSERTIONS = (width - 2) * (height - 2) * 10 - MAX_ATTEMPTS_FROM_SCRATCH = 10 - - attempt_number = 0 - while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH: - cells_to_fill = [] - rail = [] - for r in range(height): - rail.append([None] * width) - if r > 0 and r < height - 1: - cells_to_fill = cells_to_fill + [(r, c) for c in range(1, width - 1)] - - num_insertions = 0 - while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0: - # cell = random.sample(cells_to_fill, 1)[0] - cell = cells_to_fill[np.random.choice(len(cells_to_fill), 1)[0]] - cells_to_fill.remove(cell) - row = cell[0] - col = cell[1] - - # look at its neighbors and see what are the possible transitions - # that can be chosen from, if any. - valid_template = [-1, -1, -1, -1] - - for el in [(0, 2, (-1, 0)), - (1, 3, (0, 1)), - (2, 0, (1, 0)), - (3, 1, (0, -1))]: # N, E, S, W - neigh_trans = rail[row + el[2][0]][col + el[2][1]] - if neigh_trans is not None: - # select transition coming from facing direction el[1] and - # moving to direction el[1] - max_bit = 0 - for k in range(4): - max_bit |= t_utils.get_transition(neigh_trans, k, el[1]) - - if max_bit: - valid_template[el[0]] = 1 - else: - valid_template[el[0]] = 0 - - possible_cell_transitions = get_matching_templates(valid_template) - - if len(possible_cell_transitions) == 0: # NO VALID TRANSITIONS - # no cell can be filled in without violating some transitions - # can a dead-end solve the problem? - if valid_template.count(1) == 1: - for k in range(4): - if valid_template[k] == 1: - rot = 0 - if k == 0: - rot = 180 - elif k == 1: - rot = 270 - elif k == 2: - rot = 0 - elif k == 3: - rot = 90 - - rail[row][col] = t_utils.rotate_transition(int('0010000000000000', 2), rot) - num_insertions += 1 - - break - - else: - # can I get valid transitions by removing a single - # neighboring cell? - bestk = -1 - besttrans = [] - for k in range(4): - tmp_template = valid_template[:] - tmp_template[k] = -1 - possible_cell_transitions = get_matching_templates(tmp_template) - if len(possible_cell_transitions) > len(besttrans): - besttrans = possible_cell_transitions - bestk = k - - if bestk >= 0: - # Replace the corresponding cell with None, append it - # to cells to fill, fill in a transition in the current - # cell. - replace_row = row - 1 - replace_col = col - if bestk == 1: - replace_row = row - replace_col = col + 1 - elif bestk == 2: - replace_row = row + 1 - replace_col = col - elif bestk == 3: - replace_row = row - replace_col = col - 1 - - cells_to_fill.append((replace_row, replace_col)) - rail[replace_row][replace_col] = None - - possible_transitions, possible_probabilities = zip(*besttrans) - possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities] - - rail[row][col] = np.random.choice(possible_transitions, - p=possible_probabilities) - num_insertions += 1 - - else: - print('WARNING: still nothing!') - rail[row][col] = int('0000000000000000', 2) - num_insertions += 1 - pass - - else: - possible_transitions, possible_probabilities = zip(*possible_cell_transitions) - possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities] - - rail[row][col] = np.random.choice(possible_transitions, - p=possible_probabilities) - num_insertions += 1 - - if num_insertions == MAX_INSERTIONS: - # Failed to generate a valid level; try again for a number of times - attempt_number += 1 - else: - break - - if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH: - print('ERROR: failed to generate level') - - # Finally pad the border of the map with dead-ends to avoid border issues; - # at most 1 transition in the neigh cell - for r in range(height): - # Check for transitions coming from [r][1] to WEST - max_bit = 0 - neigh_trans = rail[r][1] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) - max_bit = max_bit | (neigh_trans_from_direction & 1) - if max_bit: - rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270) - else: - rail[r][0] = int('0000000000000000', 2) - - # Check for transitions coming from [r][-2] to EAST - max_bit = 0 - neigh_trans = rail[r][-2] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) - max_bit = max_bit | (neigh_trans_from_direction & (1 << 2)) - if max_bit: - rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2), - 90) - else: - rail[r][-1] = int('0000000000000000', 2) - - for c in range(width): - # Check for transitions coming from [1][c] to NORTH - max_bit = 0 - neigh_trans = rail[1][c] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) - max_bit = max_bit | (neigh_trans_from_direction & (1 << 3)) - if max_bit: - rail[0][c] = int('0010000000000000', 2) - else: - rail[0][c] = int('0000000000000000', 2) - - # Check for transitions coming from [-2][c] to SOUTH - max_bit = 0 - neigh_trans = rail[-2][c] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) - max_bit = max_bit | (neigh_trans_from_direction & (1 << 1)) - if max_bit: - rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180) - else: - rail[-1][c] = int('0000000000000000', 2) - - # For display only, wrong levels - for r in range(height): - for c in range(width): - if rail[r][c] is None: - rail[r][c] = int('0000000000000000', 2) - - tmp_rail = np.asarray(rail, dtype=np.uint16) - - return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) - return_rail.grid = tmp_rail - return return_rail - - return generator +# from flatland.core.transitions import Grid8Transitions, RailEnvTransitions +# from flatland.core.transition_map import GridTransitionMap class EnvAgentStatic(object): diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index ebaf905..5f6625a 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -167,26 +167,32 @@ class JupEditor(object): # get the direction index for the 2 transitions liTrans = [] for rcTrans in rc2Trans: + # gRCTrans - rcTrans gives an array of vector differences between our rcTrans + # and the 4 directions stored in gRCTrans. + # Where the vector difference is zero, we have a match... + # np.all detects where the whole row,col vector is zero. + # argwhere gives the index of the zero vector, ie the direction index iTrans = np.argwhere(np.all(self.gRCTrans - rcTrans == 0, axis=1)) if len(iTrans) > 0: iTrans = iTrans[0][0] liTrans.append(iTrans) + # check that we have two transitions if len(liTrans) == 2: # Set the transition - # oEnv.rail.set_transition((*rcLast, iTransLast), iTrans, True) # does nothing - iValCell = env.rail.transitions.set_transition( - env.rail.grid[tuple(rcMiddle)], liTrans[0], liTrans[1], bTransition) + env.rail.set_transition((*rcMiddle, liTrans[0]), liTrans[1], True) + # iValCell = env.rail.transitions.set_transition( + # env.rail.grid[tuple(rcMiddle)], liTrans[0], liTrans[1], bTransition) # Also set the reverse transition - iValCell = env.rail.transitions.set_transition( - iValCell, - (liTrans[1] + 2) % 4, - (liTrans[0] + 2) % 4, - bTransition) + # iValCell = env.rail.transitions.set_transition( + # iValCell, + # (liTrans[1] + 2) % 4, # use the reversed outbound transition for inbound + # (liTrans[0] + 2) % 4, # use the reversed inbound transition for outbound + # bTransition) # Write the cell transition value back into the grid - env.rail.grid[tuple(rcMiddle)] = iValCell + # env.rail.grid[tuple(rcMiddle)] = iValCell rcHistory.pop(0) # remove the last-but-one diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py index 55c229e..9ec0db0 100644 --- a/tests/test_env_observation_builder.py +++ b/tests/test_env_observation_builder.py @@ -5,7 +5,8 @@ import numpy as np from flatland.core.env_observation_builder import GlobalObsForRailEnv from flatland.core.transition_map import GridTransitionMap, Grid4Transitions -from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator +from flatland.envs.rail_env import RailEnv +from flatland.envs.generators import rail_from_GridTransitionMap_generator """Tests for `flatland` package.""" diff --git a/tests/test_environments.py b/tests/test_environments.py index b46bb38..a10fb06 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -2,11 +2,13 @@ # -*- coding: utf-8 -*- import numpy as np -from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator +from flatland.envs.rail_env import RailEnv +from flatland.envs.generators import rail_from_GridTransitionMap_generator from flatland.core.transitions import Grid4Transitions from flatland.core.transition_map import GridTransitionMap from flatland.core.env_observation_builder import GlobalObsForRailEnv + """Tests for `flatland` package.""" diff --git a/tests/test_transitions.py b/tests/test_transitions.py index 2ebfc46..69a7953 100644 --- a/tests/test_transitions.py +++ b/tests/test_transitions.py @@ -3,7 +3,8 @@ """Tests for `flatland` package.""" from flatland.core.transitions import RailEnvTransitions, Grid8Transitions -from flatland.envs.rail_env import validate_new_transition +# from flatland.envs.rail_env import validate_new_transition +from flatland.envs.env_utils import validate_new_transition import numpy as np -- GitLab