diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py index 16ec480f4f97ca8d200dd5e80b7e5d7c7ece2218..9ccef3fdf58aab56f1819a928477118fa8a467a5 100644 --- a/examples/custom_railmap_example.py +++ b/examples/custom_railmap_example.py @@ -3,7 +3,7 @@ import random import numpy as np from flatland.core.transition_map import GridTransitionMap -from flatland.core.transitions import RailEnvTransitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool diff --git a/flatland/core/grid/__init__.py b/flatland/core/grid/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/flatland/core/grid/grid4.py b/flatland/core/grid/grid4.py new file mode 100644 index 0000000000000000000000000000000000000000..5c09f0ac8ba86ed7987aefcf92a541f2ea5d1de4 --- /dev/null +++ b/flatland/core/grid/grid4.py @@ -0,0 +1,231 @@ +from enum import IntEnum + +import numpy as np + +from flatland.core.transitions import Transitions + + +class Grid4TransitionsEnum(IntEnum): + NORTH = 0 + EAST = 1 + SOUTH = 2 + WEST = 3 + + +class Grid4Transitions(Transitions): + """ + Grid4Transitions class derived from Transitions. + + Special case of `Transitions' over a 2D-grid (FlatLand). + Transitions are possible to neighboring cells on the grid if allowed. + GridTransitions keeps track of valid transitions supplied as `transitions' + list, each represented as a bitmap of 16 bits. + + Whether a transition is allowed or not depends on which direction an agent + inside the cell is facing (0=North, 1=East, 2=South, 3=West) and which + direction the agent wants to move to + (North, East, South, West, relative to the cell). + Each transition (orientation, direction) + can be allowed (1) or forbidden (0). + + For example, in case of no diagonal transitions on the grid, the 16 bits + of the transition bitmaps are organized in 4 blocks of 4 bits each, the + direction that the agent is facing. + E.g., the most-significant 4-bits represent the possible movements (NESW) + if the agent is facing North, etc... + + agent's direction: North East South West + agent's allowed movements: [nesw] [nesw] [nesw] [nesw] + example: 1000 0000 0010 0000 + + In the example, the agent can move from North to South and viceversa. + """ + + def __init__(self, transitions): + self.transitions = transitions + self.sDirs = "NESW" + self.lsDirs = list(self.sDirs) + + # row,col delta for each direction + self.gDir2dRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) + + # These bits represent all the possible dead ends + self.maskDeadEnds = 0b0010000110000100 + + def get_type(self): + return np.uint16 + + def get_transitions(self, cell_transition, orientation): + """ + Get the 4 possible transitions ((N,E,S,W), 4 elements tuple + if no diagonal transitions allowed) available for an agent oriented + in direction `orientation' and inside a cell with + transitions `cell_transition'. + + Parameters + ---------- + cell_transition : int + 16 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + + Returns + ------- + tuple + List of the validity of transitions in the cell. + + """ + bits = (cell_transition >> ((3 - orientation) * 4)) + return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1) + + def set_transitions(self, cell_transition, orientation, new_transitions): + """ + Set the possible transitions (e.g., (N,E,S,W), 4 elements tuple + if no diagonal transitions allowed) available for an agent + oriented in direction `orientation' and inside a cell with transitions + `cell_transition'. A new `cell_transition' is returned with + the specified bits replaced by `new_transitions'. + + Parameters + ---------- + cell_transition : int + 16 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + new_transitions : tuple + Tuple of new transitions validitiy for the cell. + + Returns + ------- + int + An updated bitmap that replaces the original transitions validity + of `cell_transition' with `new_transitions', for the appropriate + `orientation'. + + """ + mask = (1 << ((4 - orientation) * 4)) - (1 << ((3 - orientation) * 4)) + negmask = ~mask + + new_transitions = \ + (new_transitions[0] & 1) << 3 | \ + (new_transitions[1] & 1) << 2 | \ + (new_transitions[2] & 1) << 1 | \ + (new_transitions[3] & 1) + + cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4)) + + return cell_transition + + def get_transition(self, cell_transition, orientation, direction): + """ + Get the transition bit (1 value) that determines whether an agent + oriented in direction `orientation' and inside a cell with transitions + `cell_transition' can move to the cell in direction `direction' + relative to the current cell. + + Parameters + ---------- + cell_transition : int + 16 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + direction : int + Direction of movement whose validity is to be tested. + + Returns + ------- + int + Validity of the requested transition: 0/1 allowed/not allowed. + + """ + return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1 + + def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False): + """ + Set the transition bit (1 value) that determines whether an agent + oriented in direction `orientation' and inside a cell with transitions + `cell_transition' can move to the cell in direction `direction' + relative to the current cell. + + Parameters + ---------- + cell_transition : int + 16 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + direction : int + Direction of movement whose validity is to be tested. + new_transition : int + Validity of the requested transition: 0/1 allowed/not allowed. + remove_deadends -- boolean, default False + remove all deadend transitions. + Returns + ------- + int + An updated bitmap that replaces the original transitions validity + of `cell_transition' with `new_transitions', for the appropriate + `orientation'. + + """ + if new_transition: + cell_transition |= (1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction))) + else: + cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction))) + + if remove_deadends: + cell_transition = self.remove_deadends(cell_transition) + + return cell_transition + + def rotate_transition(self, cell_transition, rotation=0): + """ + Clockwise-rotate a 16-bit transition bitmap by + rotation={0, 90, 180, 270} degrees. + + Parameters + ---------- + cell_transition : int + 16 bits used to encode the valid transitions for a cell. + rotation : int + Angle by which to clock-wise rotate the transition bits in + `cell_transition' by. I.e., rotation={0, 90, 180, 270} degrees. + + Returns + ------- + int + An updated bitmap that replaces the original transitions bits + with the equivalent bitmap after rotation. + + """ + # Rotate the individual bits in each block + value = cell_transition + rotation = rotation // 90 + for i in range(4): + block_tuple = self.get_transitions(value, i) + block_tuple = block_tuple[(4 - rotation):] + block_tuple[:(4 - rotation)] + value = self.set_transitions(value, i, block_tuple) + + # Rotate the 4-bits blocks + value = ((value & (2 ** (rotation * 4) - 1)) << ((4 - rotation) * 4)) | (value >> (rotation * 4)) + + cell_transition = value + return cell_transition + + def get_direction_enum(self) -> IntEnum: + return Grid4TransitionsEnum + + def has_deadend(self, cell_transition): + """ + Checks if one entry can only by exited by a turn-around. + """ + if cell_transition & self.maskDeadEnds > 0: + return True + else: + return False + + def remove_deadends(self, cell_transition): + """ + Remove all turn-arounds (e.g. N-S, S-N, E-W,...). + """ + cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff + return cell_transition diff --git a/flatland/core/grid/grid8.py b/flatland/core/grid/grid8.py new file mode 100644 index 0000000000000000000000000000000000000000..2ba379a5ecb4099de999905d34daf91ccccac640 --- /dev/null +++ b/flatland/core/grid/grid8.py @@ -0,0 +1,203 @@ +from enum import IntEnum + +import numpy as np + +from flatland.core.transitions import Transitions + + +class Grid8TransitionsEnum(IntEnum): + NORTH = 0 + NORTH_EAST = 1 + EAST = 2 + SOUTH_EAST = 3 + SOUTH = 4 + SOUTH_WEST = 5 + WEST = 6 + NORTH_WEST = 7 + + +class Grid8Transitions(Transitions): + """ + Grid8Transitions class derived from Transitions. + + Special case of `Transitions' over a 2D-grid (FlatLand). + Transitions are possible to neighboring cells on the grid if allowed. + GridTransitions keeps track of valid transitions supplied as `transitions' + list, each represented as a bitmap of 64 bits. + + 0=North, 1=North-East, etc. + + """ + + def __init__(self, transitions): + self.transitions = transitions + + def get_type(self): + return np.uint64 + + def get_transitions(self, cell_transition, orientation): + """ + Get the 8 possible transitions. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + + Returns + ------- + tuple + List of the validity of transitions in the cell. + + """ + bits = (np.uint64(cell_transition) >> np.uint64((7 - orientation) * 8)) + cell_transition = ( + (bits >> np.uint64(7)) & np.uint64(1), + (bits >> np.uint64(6)) & np.uint64(1), + (bits >> np.uint64(5)) & np.uint64(1), + (bits >> np.uint64(4)) & np.uint64(1), + (bits >> np.uint64(3)) & np.uint64(1), + (bits >> np.uint64(2)) & np.uint64(1), + (bits >> np.uint64(1)) & np.uint64(1), + bits & np.uint64(1)) + + return cell_transition + + def set_transitions(self, cell_transition, orientation, new_transitions): + """ + Set the possible transitions. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + new_transitions : tuple + Tuple of new transitions validitiy for the cell. + + Returns + ------- + int + An updated bitmap that replaces the original transitions validity + of `cell_transition' with `new_transitions', for the appropriate + `orientation'. + + """ + mask = (1 << ((8 - orientation) * 8)) - (1 << ((7 - orientation) * 8)) + negmask = ~mask + + new_transitions = \ + (int(new_transitions[0]) & 1) << 7 | \ + (int(new_transitions[1]) & 1) << 6 | \ + (int(new_transitions[2]) & 1) << 5 | \ + (int(new_transitions[3]) & 1) << 4 | \ + (int(new_transitions[4]) & 1) << 3 | \ + (int(new_transitions[5]) & 1) << 2 | \ + (int(new_transitions[6]) & 1) << 1 | \ + (int(new_transitions[7]) & 1) + + cell_transition = (int(cell_transition) & negmask) | (new_transitions << ((7 - orientation) * 8)) + + return cell_transition + + def get_transition(self, cell_transition, orientation, direction): + """ + Get the transition bit (1 value) that determines whether an agent + oriented in direction `orientation' and inside a cell with transitions + `cell_transition' can move to the cell in direction `direction' + relative to the current cell. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + direction : int + Direction of movement whose validity is to be tested. + + Returns + ------- + int + Validity of the requested transition: 0/1 allowed/not allowed. + + """ + return ((cell_transition >> ((8 - 1 - orientation) * 8)) >> (8 - 1 - direction)) & 1 + + def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False): + + """ + Set the transition bit (1 value) that determines whether an agent + oriented in direction `orientation' and inside a cell with transitions + `cell_transition' can move to the cell in direction `direction' + relative to the current cell. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + direction : int + Direction of movement whose validity is to be tested. + new_transition : int + Validity of the requested transition: 0/1 allowed/not allowed. + + Returns + ------- + int + An updated bitmap that replaces the original transitions validity + of `cell_transition' with `new_transitions', for the appropriate + `orientation'. + + """ + if new_transition: + cell_transition |= (1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction))) + else: + cell_transition &= ~(1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction))) + + return cell_transition + + def rotate_transition(self, cell_transition, rotation=0): + """ + Clockwise-rotate a 64-bit transition bitmap by + rotation={0, 45, 90, 135, 180, 225, 270, 315} degrees. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + rotation : int + Angle by which to clock-wise rotate the transition bits in + `cell_transition' by. I.e., rotation={0, 45, 90, 135, 180, + 225, 270, 315} degrees. + + Returns + ------- + int + An updated bitmap that replaces the original transitions bits + with the equivalent bitmap after rotation. + + """ + # TODO: WARNING: this part of the function has never been tested! + + # Rotate the individual bits in each block + value = cell_transition + rotation = rotation // 45 + for i in range(8): + block_tuple = self.get_transitions(value, i) + block_tuple = block_tuple[rotation:] + block_tuple[:rotation] + value = self.set_transitions(value, i, block_tuple) + + # Rotate the 8bits blocks + value = ((value & (2 ** (rotation * 8) - 1)) << ((8 - rotation) * 8)) | (value >> (rotation * 8)) + + cell_transition = value + + return cell_transition + + def get_direction_enum(self) -> IntEnum: + return Grid8TransitionsEnum diff --git a/flatland/core/grid/rail_env_grid.py b/flatland/core/grid/rail_env_grid.py new file mode 100644 index 0000000000000000000000000000000000000000..c043b42f1ca84ba9d0f7a68f5e18a192ff374d7a --- /dev/null +++ b/flatland/core/grid/rail_env_grid.py @@ -0,0 +1,111 @@ +from flatland.core.grid.grid4 import Grid4Transitions + + +class RailEnvTransitions(Grid4Transitions): + """ + Special case of `GridTransitions' over a 2D-grid, with a pre-defined set + of transitions mimicking the types of real Swiss rail connections. + + -------------------------------------------------------------------------- + + As no diagonal transitions are allowed in the RailEnv environment, the + possible transitions for RailEnv from a cell to its neighboring ones + are represented over 16 bits. + + The 16 bits are organized in 4 blocks of 4 bits each, the direction that + the agent is facing. + E.g., the most-significant 4-bits represent the possible movements (NESW) + if the agent is facing North, etc... + + agent's direction: North East South West + agent's allowed movements: [nesw] [nesw] [nesw] [nesw] + example: 1000 0000 0010 0000 + + In the example, the agent can move from North to South and viceversa. + """ + + # Contains the basic transitions; + # the set of all valid transitions is obtained by successive 90-degree rotation of one of these basic transitions. + transition_list = [int('0000000000000000', 2), # empty cell - Case 0 + int('1000000000100000', 2), # Case 1 - straight + int('1001001000100000', 2), # Case 2 - simple switch + int('1000010000100001', 2), # Case 3 - diamond drossing + int('1001011000100001', 2), # Case 4 - single slip + int('1100110000110011', 2), # Case 5 - double slip + int('0101001000000010', 2), # Case 6 - symmetrical + int('0010000000000000', 2), # Case 7 - dead end + int('0100000000000010', 2), # Case 1b (8) - simple turn right + int('0001001000000000', 2), # Case 1c (9) - simple turn left + int('1100000000100010', 2)] # Case 2b (10) - simple switch mirrored + + def __init__(self): + super(RailEnvTransitions, self).__init__( + transitions=self.transition_list + ) + + # create this to make validation faster + self.transitions_all = set() + for index, trans in enumerate(self.transitions): + self.transitions_all.add(trans) + if index in (2, 4, 6, 7, 8, 9, 10): + for _ in range(3): + trans = self.rotate_transition(trans, rotation=90) + self.transitions_all.add(trans) + elif index in (1, 5): + trans = self.rotate_transition(trans, rotation=90) + self.transitions_all.add(trans) + + def print(self, cell_transition): + print(" NESW") + print("N", format(cell_transition >> (3 * 4) & 0xF, '04b')) + print("E", format(cell_transition >> (2 * 4) & 0xF, '04b')) + print("S", format(cell_transition >> (1 * 4) & 0xF, '04b')) + print("W", format(cell_transition >> (0 * 4) & 0xF, '04b')) + + def repr(self, cell_transition, version=0): + """ + Provide a string representation of the cell transitions. + This class doesn't represent an individual cell, + but a way of interpreting the contents of a cell. + So using the ad hoc name repr rather than __repr__. + """ + # binary format string without leading 0b + sbinTrans = format(cell_transition, "#018b")[2:] + if version == 0: + sRepr = " ".join([ + "{}:{}".format(sDir, sbinTrans[i:(i + 4)]) + for i, sDir in + zip( + range(0, len(sbinTrans), 4), + self.lsDirs)]) # NESW + return sRepr + + if version == 1: + lsRepr = [] + for iDirIn in range(0, 4): + sDirTrans = sbinTrans[(iDirIn * 4):(iDirIn * 4 + 4)] + if sDirTrans == "0000": + continue + sDirsOut = [ + self.lsDirs[iDirOut] + for iDirOut in range(0, 4) + if sDirTrans[iDirOut] == "1"] + lsRepr.append(self.lsDirs[iDirIn] + ":" + "".join(sDirsOut)) + + return ", ".join(lsRepr) + + def is_valid(self, cell_transition): + """ + Checks if a cell transition is a valid cell setup. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + + Returns + ------- + Boolean + True or False + """ + return cell_transition in self.transitions_all diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 43b9a72ae7906c8a7dce58a2b58155a99ff760ae..6c0b92a7c6d45187dae5158fb0a81c9fab2d7280 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -6,7 +6,7 @@ import numpy as np from importlib_resources import path from numpy import array -from .transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions +from flatland.core.grid.grid4 import Grid4Transitions class TransitionMap: @@ -73,7 +73,7 @@ class TransitionMap: Returns ------- - int or float (depending on derived class) + int or float (depending on Transitions used) Validity of the requested transition (e.g., 0/1 allowed/not allowed, a probability in [0,1], etc...) @@ -95,7 +95,7 @@ class TransitionMap: Index of the transition to probe, as index in the tuple returned by get_transitions(). e.g., the NESW direction of movement, for agents on a grid. - new_transition : int or float (depending on derived class) + new_transition : int or float (depending on Transitions used) Validity of the requested transition (e.g., 0/1 allowed/not allowed, a probability in [0,1], etc...) @@ -130,10 +130,7 @@ class GridTransitionMap(TransitionMap): self.height = height self.transitions = transitions - if isinstance(self.transitions, Grid4Transitions) or isinstance(self.transitions, RailEnvTransitions): - self.grid = np.ndarray((height, width), dtype=np.uint16) - elif isinstance(self.transitions, Grid8Transitions): - self.grid = np.ndarray((height, width), dtype=np.uint64) + self.grid = np.zeros((height, width), dtype=self.transitions.get_type()) def get_transitions(self, cell_id): """ @@ -156,14 +153,12 @@ class GridTransitionMap(TransitionMap): List of the validity of transitions in the cell. """ + assert len(cell_id) in (2, 3), \ + 'GridTransitionMap.get_transitions() ERROR: cell_id tuple must have length 2 or 3.' if len(cell_id) == 3: return self.transitions.get_transitions(self.grid[cell_id[0]][cell_id[1]], cell_id[2]) elif len(cell_id) == 2: return self.grid[cell_id[0]][cell_id[1]] - else: - print('GridTransitionMap.get_transitions() ERROR: \ - wrong cell_id tuple.') - return () def set_transitions(self, cell_id, new_transitions): """ @@ -182,15 +177,14 @@ class GridTransitionMap(TransitionMap): Tuple of new transitions validitiy for the cell. """ + assert len(cell_id) in (2, 3), \ + 'GridTransitionMap.set_transitions() ERROR: cell_id tuple must have length 2 or 3.' if len(cell_id) == 3: self.grid[cell_id[0]][cell_id[1]] = self.transitions.set_transitions(self.grid[cell_id[0]][cell_id[1]], cell_id[2], new_transitions) elif len(cell_id) == 2: self.grid[cell_id[0]][cell_id[1]] = new_transitions - else: - print('GridTransitionMap.get_transitions() ERROR: \ - wrong cell_id tuple.') def get_transition(self, cell_id, transition_index): """ @@ -210,15 +204,14 @@ class GridTransitionMap(TransitionMap): Returns ------- - int or float (depending on derived class) + int or float (depending on Transitions used in the ) Validity of the requested transition (e.g., 0/1 allowed/not allowed, a probability in [0,1], etc...) """ - if len(cell_id) != 3: - print('GridTransitionMap.get_transition() ERROR: \ - wrong cell_id tuple.') - return () + + assert len(cell_id) == 3, \ + 'GridTransitionMap.get_transition() ERROR: cell_id tuple must have length 2 or 3.' return self.transitions.get_transition(self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index) def set_transition(self, cell_id, transition_index, new_transition, remove_deadends=False): @@ -236,15 +229,13 @@ class GridTransitionMap(TransitionMap): Index of the transition to probe, as index in the tuple returned by get_transitions(). e.g., the NESW direction of movement, for agents on a grid. - new_transition : int or float (depending on derived class) + new_transition : int or float (depending on Transitions used in the map.) Validity of the requested transition (e.g., 0/1 allowed/not allowed, a probability in [0,1], etc...) """ - if len(cell_id) != 3: - print('GridTransitionMap.set_transition() ERROR: \ - wrong cell_id tuple.') - return + assert len(cell_id) == 3, \ + 'GridTransitionMap.set_transition() ERROR: cell_id tuple must have length 3.' self.grid[cell_id[0]][cell_id[1]] = self.transitions.set_transition( self.grid[cell_id[0]][cell_id[1]], cell_id[2], @@ -264,7 +255,7 @@ class GridTransitionMap(TransitionMap): """ np.save(filename, self.grid) - def load_transition_map(self, package, resource, override_gridsize=True): + def load_transition_map(self, package, resource): """ Load the transitions grid from `filename' (npy format). The load function only updates the transitions grid, and possibly width and height, but the object has to be @@ -289,28 +280,9 @@ class GridTransitionMap(TransitionMap): new_height = new_grid.shape[0] new_width = new_grid.shape[1] - if override_gridsize: - self.width = new_width - self.height = new_height - self.grid = new_grid - - else: - if new_grid.dtype == np.uint16: - self.grid = np.zeros((self.height, self.width), dtype=np.uint16) - elif new_grid.dtype == np.uint64: - self.grid = np.zeros((self.height, self.width), dtype=np.uint64) - - self.grid[0:min(self.height, new_height), - 0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height), - 0:min(self.width, new_width)] - - def is_cell_valid(self, rcPos): - cell_transition = self.grid[tuple(rcPos)] - - if not self.transitions.is_valid(cell_transition): - return False - else: - return True + self.width = new_width + self.height = new_height + self.grid = new_grid def cell_neighbours_valid(self, rcPos, check_this_cell=False): """ @@ -364,9 +336,6 @@ class GridTransitionMap(TransitionMap): return True - def cell_repr(self, rcPos): - return self.transitions.repr(self.get_transitions(rcPos)) - # TODO: GIACOMO: is it better to provide those methods with lists of cell_ids # (most general implementation) or to make Grid-class specific methods for # slicing over the 3 dimensions? I'd say both perhaps. diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index 6c38a39cea696b9412bca35062978a1d8d19eef0..29b57c40faf567e6a9aa4b679df6af6fcf0909ba 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -5,8 +5,6 @@ possible transitions over a 2D grid. """ from enum import IntEnum -import numpy as np - class Transitions: """ @@ -17,6 +15,9 @@ class Transitions: `orientation' and moving into direction `direction') """ + def get_type(self): + raise NotImplementedError() + def get_transitions(self, cell_transition, orientation): """ Return a tuple of transitions available in a cell specified by @@ -132,529 +133,3 @@ class Transitions: def get_direction_enum(self) -> IntEnum: raise NotImplementedError() - - -class Grid4TransitionsEnum(IntEnum): - NORTH = 0 - EAST = 1 - SOUTH = 2 - WEST = 3 - - -class Grid4Transitions(Transitions): - """ - Grid4Transitions class derived from Transitions. - - Special case of `Transitions' over a 2D-grid (FlatLand). - Transitions are possible to neighboring cells on the grid if allowed. - GridTransitions keeps track of valid transitions supplied as `transitions' - list, each represented as a bitmap of 16 bits. - - Whether a transition is allowed or not depends on which direction an agent - inside the cell is facing (0=North, 1=East, 2=South, 3=West) and which - direction the agent wants to move to - (North, East, South, West, relative to the cell). - Each transition (orientation, direction) - can be allowed (1) or forbidden (0). - - For example, in case of no diagonal transitions on the grid, the 16 bits - of the transition bitmaps are organized in 4 blocks of 4 bits each, the - direction that the agent is facing. - E.g., the most-significant 4-bits represent the possible movements (NESW) - if the agent is facing North, etc... - - agent's direction: North East South West - agent's allowed movements: [nesw] [nesw] [nesw] [nesw] - example: 1000 0000 0010 0000 - - In the example, the agent can move from North to South and viceversa. - """ - - def __init__(self, transitions): - self.transitions = transitions - self.sDirs = "NESW" - self.lsDirs = list(self.sDirs) - - # row,col delta for each direction - self.gDir2dRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) - - def get_transitions(self, cell_transition, orientation): - """ - Get the 4 possible transitions ((N,E,S,W), 4 elements tuple - if no diagonal transitions allowed) available for an agent oriented - in direction `orientation' and inside a cell with - transitions `cell_transition'. - - Parameters - ---------- - cell_transition : int - 16 bits used to encode the valid transitions for a cell. - orientation : int - Orientation of the agent inside the cell. - - Returns - ------- - tuple - List of the validity of transitions in the cell. - - """ - bits = (cell_transition >> ((3 - orientation) * 4)) - return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1) - - def set_transitions(self, cell_transition, orientation, new_transitions): - """ - Set the possible transitions (e.g., (N,E,S,W), 4 elements tuple - if no diagonal transitions allowed) available for an agent - oriented in direction `orientation' and inside a cell with transitions - `cell_transition'. A new `cell_transition' is returned with - the specified bits replaced by `new_transitions'. - - Parameters - ---------- - cell_transition : int - 16 bits used to encode the valid transitions for a cell. - orientation : int - Orientation of the agent inside the cell. - new_transitions : tuple - Tuple of new transitions validitiy for the cell. - - Returns - ------- - int - An updated bitmap that replaces the original transitions validity - of `cell_transition' with `new_transitions', for the appropriate - `orientation'. - - """ - mask = (1 << ((4 - orientation) * 4)) - (1 << ((3 - orientation) * 4)) - negmask = ~mask - - new_transitions = \ - (new_transitions[0] & 1) << 3 | \ - (new_transitions[1] & 1) << 2 | \ - (new_transitions[2] & 1) << 1 | \ - (new_transitions[3] & 1) - - cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4)) - - return cell_transition - - def get_transition(self, cell_transition, orientation, direction): - """ - Get the transition bit (1 value) that determines whether an agent - oriented in direction `orientation' and inside a cell with transitions - `cell_transition' can move to the cell in direction `direction' - relative to the current cell. - - Parameters - ---------- - cell_transition : int - 16 bits used to encode the valid transitions for a cell. - orientation : int - Orientation of the agent inside the cell. - direction : int - Direction of movement whose validity is to be tested. - - Returns - ------- - int - Validity of the requested transition: 0/1 allowed/not allowed. - - """ - return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1 - - def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False): - """ - Set the transition bit (1 value) that determines whether an agent - oriented in direction `orientation' and inside a cell with transitions - `cell_transition' can move to the cell in direction `direction' - relative to the current cell. - - Parameters - ---------- - cell_transition : int - 16 bits used to encode the valid transitions for a cell. - orientation : int - Orientation of the agent inside the cell. - direction : int - Direction of movement whose validity is to be tested. - new_transition : int - Validity of the requested transition: 0/1 allowed/not allowed. - remove_deadends -- boolean, default False - remove all deadend transitions. - Returns - ------- - int - An updated bitmap that replaces the original transitions validity - of `cell_transition' with `new_transitions', for the appropriate - `orientation'. - - """ - if new_transition: - cell_transition |= (1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction))) - else: - cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction))) - - if remove_deadends: - cell_transition = self.remove_deadends(cell_transition) - - return cell_transition - - def rotate_transition(self, cell_transition, rotation=0): - """ - Clockwise-rotate a 16-bit transition bitmap by - rotation={0, 90, 180, 270} degrees. - - Parameters - ---------- - cell_transition : int - 16 bits used to encode the valid transitions for a cell. - rotation : int - Angle by which to clock-wise rotate the transition bits in - `cell_transition' by. I.e., rotation={0, 90, 180, 270} degrees. - - Returns - ------- - int - An updated bitmap that replaces the original transitions bits - with the equivalent bitmap after rotation. - - """ - # Rotate the individual bits in each block - value = cell_transition - rotation = rotation // 90 - for i in range(4): - block_tuple = self.get_transitions(value, i) - block_tuple = block_tuple[(4 - rotation):] + block_tuple[:(4 - rotation)] - value = self.set_transitions(value, i, block_tuple) - - # Rotate the 4-bits blocks - value = ((value & (2 ** (rotation * 4) - 1)) << ((4 - rotation) * 4)) | (value >> (rotation * 4)) - - cell_transition = value - return cell_transition - - def get_direction_enum(self) -> IntEnum: - return Grid4TransitionsEnum - - -class Grid8TransitionsEnum(IntEnum): - NORTH = 0 - NORTH_EAST = 1 - EAST = 2 - SOUTH_EAST = 3 - SOUTH = 4 - SOUTH_WEST = 5 - WEST = 6 - NORTH_WEST = 7 - - -class Grid8Transitions(Transitions): - """ - Grid8Transitions class derived from Transitions. - - Special case of `Transitions' over a 2D-grid (FlatLand). - Transitions are possible to neighboring cells on the grid if allowed. - GridTransitions keeps track of valid transitions supplied as `transitions' - list, each represented as a bitmap of 64 bits. - - 0=North, 1=North-East, etc. - - """ - - def __init__(self, transitions): - self.transitions = transitions - - def get_transitions(self, cell_transition, orientation): - """ - Get the 8 possible transitions. - - Parameters - ---------- - cell_transition : int - 64 bits used to encode the valid transitions for a cell. - orientation : int - Orientation of the agent inside the cell. - - Returns - ------- - tuple - List of the validity of transitions in the cell. - - """ - bits = (cell_transition >> ((7 - orientation) * 8)) - cell_transition = ( - (bits >> 7) & 1, - (bits >> 6) & 1, - (bits >> 5) & 1, - (bits >> 4) & 1, - (bits >> 3) & 1, - (bits >> 2) & 1, - (bits >> 1) & 1, - (bits) & 1) - - return cell_transition - - def set_transitions(self, cell_transition, orientation, new_transitions): - """ - Set the possible transitions. - - Parameters - ---------- - cell_transition : int - 64 bits used to encode the valid transitions for a cell. - orientation : int - Orientation of the agent inside the cell. - new_transitions : tuple - Tuple of new transitions validitiy for the cell. - - Returns - ------- - int - An updated bitmap that replaces the original transitions validity - of `cell_transition' with `new_transitions', for the appropriate - `orientation'. - - """ - mask = (1 << ((8 - orientation) * 8)) - (1 << ((7 - orientation) * 8)) - negmask = ~mask - - new_transitions = \ - (new_transitions[0] & 1) << 7 | \ - (new_transitions[1] & 1) << 6 | \ - (new_transitions[2] & 1) << 5 | \ - (new_transitions[3] & 1) << 4 | \ - (new_transitions[4] & 1) << 3 | \ - (new_transitions[5] & 1) << 2 | \ - (new_transitions[6] & 1) << 1 | \ - (new_transitions[7] & 1) - - cell_transition = (cell_transition & negmask) | (new_transitions << ((7 - orientation) * 8)) - - return cell_transition - - def get_transition(self, cell_transition, orientation, direction): - """ - Get the transition bit (1 value) that determines whether an agent - oriented in direction `orientation' and inside a cell with transitions - `cell_transition' can move to the cell in direction `direction' - relative to the current cell. - - Parameters - ---------- - cell_transition : int - 64 bits used to encode the valid transitions for a cell. - orientation : int - Orientation of the agent inside the cell. - direction : int - Direction of movement whose validity is to be tested. - - Returns - ------- - int - Validity of the requested transition: 0/1 allowed/not allowed. - - """ - return ((cell_transition >> ((8 - 1 - orientation) * 8)) >> (8 - 1 - direction)) & 1 - - def set_transition(self, cell_transition, orientation, direction, - new_transition): - """ - Set the transition bit (1 value) that determines whether an agent - oriented in direction `orientation' and inside a cell with transitions - `cell_transition' can move to the cell in direction `direction' - relative to the current cell. - - Parameters - ---------- - cell_transition : int - 64 bits used to encode the valid transitions for a cell. - orientation : int - Orientation of the agent inside the cell. - direction : int - Direction of movement whose validity is to be tested. - new_transition : int - Validity of the requested transition: 0/1 allowed/not allowed. - - Returns - ------- - int - An updated bitmap that replaces the original transitions validity - of `cell_transition' with `new_transitions', for the appropriate - `orientation'. - - """ - if new_transition: - cell_transition |= (1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction))) - else: - cell_transition &= ~(1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction))) - - return cell_transition - - def rotate_transition(self, cell_transition, rotation=0): - """ - Clockwise-rotate a 64-bit transition bitmap by - rotation={0, 45, 90, 135, 180, 225, 270, 315} degrees. - - Parameters - ---------- - cell_transition : int - 64 bits used to encode the valid transitions for a cell. - rotation : int - Angle by which to clock-wise rotate the transition bits in - `cell_transition' by. I.e., rotation={0, 45, 90, 135, 180, - 225, 270, 315} degrees. - - Returns - ------- - int - An updated bitmap that replaces the original transitions bits - with the equivalent bitmap after rotation. - - """ - # TODO: WARNING: this part of the function has never been tested! - - # Rotate the individual bits in each block - value = cell_transition - rotation = rotation // 45 - for i in range(8): - block_tuple = self.get_transitions(value, i) - block_tuple = block_tuple[rotation:] + block_tuple[:rotation] - value = self.set_transitions(value, i, block_tuple) - - # Rotate the 8bits blocks - value = ((value & (2 ** (rotation * 8) - 1)) << ((8 - rotation) * 8)) | (value >> (rotation * 8)) - - cell_transition = value - - return cell_transition - - def get_direction_enum(self) -> IntEnum: - return Grid8TransitionsEnum - - -class RailEnvTransitions(Grid4Transitions): - """ - Special case of `GridTransitions' over a 2D-grid, with a pre-defined set - of transitions mimicking the types of real Swiss rail connections. - - -------------------------------------------------------------------------- - - As no diagonal transitions are allowed in the RailEnv environment, the - possible transitions for RailEnv from a cell to its neighboring ones - are represented over 16 bits. - - The 16 bits are organized in 4 blocks of 4 bits each, the direction that - the agent is facing. - E.g., the most-significant 4-bits represent the possible movements (NESW) - if the agent is facing North, etc... - - agent's direction: North East South West - agent's allowed movements: [nesw] [nesw] [nesw] [nesw] - example: 1000 0000 0010 0000 - - In the example, the agent can move from North to South and viceversa. - """ - - """ - transitions[] is indexed by case type/id, and returns the 4x4-bit [NESW] - transitions available as a function of the agent's orientation - (north, east, south, west) - """ - - transition_list = [int('0000000000000000', 2), # empty cell - Case 0 - int('1000000000100000', 2), # Case 1 - straight - int('1001001000100000', 2), # Case 2 - simple switch - int('1000010000100001', 2), # Case 3 - diamond drossing - int('1001011000100001', 2), # Case 4 - single slip - int('1100110000110011', 2), # Case 5 - double slip - int('0101001000000010', 2), # Case 6 - symmetrical - int('0010000000000000', 2), # Case 7 - dead end - int('0100000000000010', 2), # Case 1b (8) - simple turn right - int('0001001000000000', 2), # Case 1c (9) - simple turn left - int('1100000000100010', 2)] # Case 2b (10) - simple switch mirrored - - def __init__(self): - super(RailEnvTransitions, self).__init__( - transitions=self.transition_list - ) - - # These bits represent all the possible dead ends - self.maskDeadEnds = 0b0010000110000100 - - # create this to make validation faster - self.transitions_all = set() - for index, trans in enumerate(self.transitions): - self.transitions_all.add(trans) - if index in (2, 4, 6, 7, 8, 9, 10): - for _ in range(3): - trans = self.rotate_transition(trans, rotation=90) - self.transitions_all.add(trans) - elif index in (1, 5): - trans = self.rotate_transition(trans, rotation=90) - self.transitions_all.add(trans) - - def print(self, cell_transition): - print(" NESW") - print("N", format(cell_transition >> (3 * 4) & 0xF, '04b')) - print("E", format(cell_transition >> (2 * 4) & 0xF, '04b')) - print("S", format(cell_transition >> (1 * 4) & 0xF, '04b')) - print("W", format(cell_transition >> (0 * 4) & 0xF, '04b')) - - def repr(self, cell_transition, version=0): - """ - Provide a string representation of the cell transitions. - This class doesn't represent an individual cell, - but a way of interpreting the contents of a cell. - So using the ad hoc name repr rather than __repr__. - """ - # binary format string without leading 0b - sbinTrans = format(cell_transition, "#018b")[2:] - if version == 0: - sRepr = " ".join([ - "{}:{}".format(sDir, sbinTrans[i:(i + 4)]) - for i, sDir in - zip( - range(0, len(sbinTrans), 4), - self.lsDirs)]) # NESW - return sRepr - - if version == 1: - lsRepr = [] - for iDirIn in range(0, 4): - sDirTrans = sbinTrans[(iDirIn * 4):(iDirIn * 4 + 4)] - if sDirTrans == "0000": - continue - sDirsOut = [ - self.lsDirs[iDirOut] - for iDirOut in range(0, 4) - if sDirTrans[iDirOut] == "1"] - lsRepr.append(self.lsDirs[iDirIn] + ":" + "".join(sDirsOut)) - - return ", ".join(lsRepr) - - def is_valid(self, cell_transition): - """ - Checks if a cell transition is a valid cell setup. - - Parameters - ---------- - cell_transition : int - 64 bits used to encode the valid transitions for a cell. - - Returns - ------- - Boolean - True or False - """ - return cell_transition in self.transitions_all - - def has_deadend(self, cell_transition): - if cell_transition & self.maskDeadEnds > 0: - return True - else: - return False - - def remove_deadends(self, cell_transition): - cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff - return cell_transition diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py index cc4a0015601d0f5820cccb25b6076a0a92f67915..19da8946864245343a09857d8b2fbe968f47ba35 100644 --- a/flatland/envs/env_utils.py +++ b/flatland/envs/env_utils.py @@ -7,7 +7,7 @@ a GridTransitionMap object. import numpy as np -from flatland.core.transitions import Grid4TransitionsEnum +from flatland.core.grid.grid4 import Grid4TransitionsEnum def get_direction(pos1, pos2) -> Grid4TransitionsEnum: diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index ca14667424d2c93d1466e3b7e96c2e5c1fbd41e5..fa5cdccc91fd837ac7b034564dffb5f909b81a1e 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -1,7 +1,7 @@ import numpy as np from flatland.core.transition_map import GridTransitionMap -from flatland.core.transitions import RailEnvTransitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.envs.env_utils import distance_on_rail, connect_rail, get_direction, mirror from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail @@ -154,7 +154,7 @@ def rail_from_manual_specifications_generator(rail_spec): Parameters ------- rail_spec : list of list of tuples - List (rows) of lists (columns) of tuples, each specifying a cell for + List (rows) of lists (columns) of tuples, each specifying a rail_spec_of_cell for the RailEnv environment as (cell_type, rotation), with rotation being clock-wise and in [0, 90, 180, 270]. @@ -162,23 +162,27 @@ def rail_from_manual_specifications_generator(rail_spec): ------- function Generator function that always returns a GridTransitionMap object with - the matrix of correct 16-bit bitmaps for each cell. + the matrix of correct 16-bit bitmaps for each rail_spec_of_cell. """ def generator(width, height, num_agents, num_resets=0): - t_utils = RailEnvTransitions() + rail_env_transitions = RailEnvTransitions() height = len(rail_spec) width = len(rail_spec[0]) - rail = GridTransitionMap(width=width, height=height, transitions=t_utils) + rail = GridTransitionMap(width=width, height=height, transitions=rail_env_transitions) for r in range(height): for c in range(width): - cell = rail_spec[r][c] - if cell[0] < 0 or cell[0] >= len(t_utils.transitions): - print("ERROR - invalid cell type=", cell[0]) + rail_spec_of_cell = rail_spec[r][c] + index_basic_type_of_cell_ = rail_spec_of_cell[0] + rotation_cell_ = rail_spec_of_cell[1] + if index_basic_type_of_cell_ < 0 or index_basic_type_of_cell_ >= len(rail_env_transitions.transitions): + print("ERROR - invalid rail_spec_of_cell type=", index_basic_type_of_cell_) return [] - rail.set_transitions((r, c), t_utils.rotate_transition(t_utils.transitions[cell[0]], cell[1])) + basic_type_of_cell_ = rail_env_transitions.transitions[index_basic_type_of_cell_] + effective_transition_cell = rail_env_transitions.rotate_transition(basic_type_of_cell_, rotation_cell_) + rail.set_transitions((r, c), effective_transition_cell) agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( rail, diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index a7f91f1439f98bc2627700903b0175486f619749..ecb0697899ab64ce5bb455c33db925b8561bd14a 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -6,7 +6,7 @@ from collections import deque import numpy as np from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.transitions import Grid4TransitionsEnum +from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.env_utils import coordinate_to_position diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index c7cb8fa12ea2a18974d3d968c8f04128d0daf624..6098f434b46bf96648136aa0b6060b8881629be1 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -28,7 +28,7 @@ enable_windows_cairo_support() from cairosvg import svg2png # noqa: E402 from screeninfo import get_monitors # noqa: E402 -from flatland.core.transitions import RailEnvTransitions # noqa: E402 +from flatland.core.grid.rail_env_grid import RailEnvTransitions # noqa: E402 class PILGL(GraphicsLayer): diff --git a/flatland/utils/svg.py b/flatland/utils/svg.py index 249b4cb52c6108d48035a267022e78bde6f1cc91..b2e0284407d52cbbc53ae74c0a744a457d7fdbeb 100644 --- a/flatland/utils/svg.py +++ b/flatland/utils/svg.py @@ -3,7 +3,7 @@ import re import svgutils -from flatland.core.transitions import RailEnvTransitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions class SVG(object): diff --git a/tests/test_env_edit.py b/tests/test_env_edit.py deleted file mode 100644 index f0d86292ce926147017bab5e4d777019bbcfb143..0000000000000000000000000000000000000000 --- a/tests/test_env_edit.py +++ /dev/null @@ -1,11 +0,0 @@ -from flatland.envs.agent_utils import EnvAgentStatic -from flatland.envs.rail_env import RailEnv - - -def test_load_env(): - env = RailEnv(10, 10) - env.load_resource('env_data.tests', 'test-10x10.mpk') - - agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False) - env.add_agent_static(agent_static) - assert env.get_num_agents() == 1 diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py new file mode 100644 index 0000000000000000000000000000000000000000..5117b12af72be948bc806940169e330d254fdb9b --- /dev/null +++ b/tests/test_flatland_core_transition_map.py @@ -0,0 +1,21 @@ +from flatland.core.grid.grid4 import Grid4Transitions, Grid4TransitionsEnum +from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum +from flatland.core.transition_map import GridTransitionMap + + +def test_grid4_set_transitions(): + grid4_map = GridTransitionMap(2, 2, Grid4Transitions([])) + assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0) + grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 1) + assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (1, 0, 0, 0) + grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 0) + assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0) + + +def test_grid8_set_transitions(): + grid8_map = GridTransitionMap(2, 2, Grid8Transitions([])) + assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0) + grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 1) + assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (1, 0, 0, 0, 0, 0, 0, 0) + grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0) + assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0) diff --git a/tests/test_transitions.py b/tests/test_flatland_core_transitions.py similarity index 51% rename from tests/test_transitions.py rename to tests/test_flatland_core_transitions.py index 9d02553e7f6d2c829026498ea9be804d6d1e766f..048520c17eeddf4d1f4a4c6beeb49427887b77f4 100644 --- a/tests/test_transitions.py +++ b/tests/test_flatland_core_transitions.py @@ -4,10 +4,102 @@ """Tests for `flatland` package.""" import numpy as np -from flatland.core.transitions import RailEnvTransitions, Grid8Transitions +from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.grid.grid8 import Grid8Transitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.envs.env_utils import validate_new_transition +# remove whitespace in string; keep whitespace below for easier reading +def rw(s): + return s.replace(" ", "") + + +def test_rotate_railenv_transition(): + rail_env_transitions = RailEnvTransitions() + + # TODO test all cases + transition_cycles = [ + # empty cell - Case 0 + [int('0000000000000000', 2), int('0000000000000000', 2), int('0000000000000000', 2), + int('0000000000000000', 2)], + # Case 1 - straight + # | + # | + # | + [int(rw('1000 0000 0010 0000'), 2), int(rw('0000 0100 0000 0001'), 2)], + # Case 1b (8) - simple turn right + # _ + # | + # | + [ + int(rw('0100 0000 0000 0010'), 2), + int(rw('0001 0010 0000 0000'), 2), + int(rw('0000 1000 0001 0000'), 2), + int(rw('0000 0000 0100 1000'), 2), + ], + # Case 1c (9) - simple turn left + # _ + # | + # | + + # int('0001001000000000', 2),\ # noqa: E800 + + # Case 2 - simple left switch + # _ _| + # | + # | + [ + int(rw('1001 0010 0010 0000'), 2), + int(rw('0000 1100 0001 0001'), 2), + int(rw('1000 0000 0110 1000'), 2), + int(rw('0100 0100 0000 0011'), 2), + ], + # Case 2b (10) - simple right switch + # | + # | + # | + + # int('1100000000100010', 2) \ # noqa: E800 + + # Case 3 - diamond drossing + # int('1000010000100001', 2), \ # noqa: E800 + # Case 4 - single slip + # int('1001011000100001', 2), \ # noqa: E800 + # Case 5 - double slip + # int('1100110000110011', 2), \ # noqa: E800 + # Case 6 - symmetrical + # int('0101001000000010', 2), \ # noqa: E800 + + # Case 7 - dead end + # + # + # | + [ + int(rw('0010 0000 0000 0000'), 2), + int(rw('0000 0001 0000 0000'), 2), + int(rw('0000 0000 1000 0000'), 2), + int(rw('0000 0000 0000 0100'), 2), + ], + ] + + for index, cycle in enumerate(transition_cycles): + for i in range(4): + actual_transition = rail_env_transitions.rotate_transition(cycle[0], i * 90) + expected_transition = cycle[i % len(cycle)] + try: + assert actual_transition == expected_transition, \ + "Case {}: rotate_transition({}, {}) should equal {} but was {}.".format( + i, cycle[0], i, expected_transition, actual_transition) + except Exception as e: + print("expected:") + rail_env_transitions.print(expected_transition) + print("actual:") + rail_env_transitions.print(actual_transition) + + raise e + + def test_is_valid_railenv_transitions(): rail_env_trans = RailEnvTransitions() transition_list = rail_env_trans.transitions @@ -120,7 +212,37 @@ def test_diagonal_transitions(): # Allowing transition from north to southwest: Facing south, going SW north_southwest_transition = \ - diagonal_trans_env.set_transitions(int('0' * 64, 2), 4, (0, 0, 0, 0, 0, 1, 0, 0)) + diagonal_trans_env.set_transitions(0, 4, (0, 0, 0, 0, 0, 1, 0, 0)) assert (diagonal_trans_env.rotate_transition( south_northeast_transition, 180) == north_southwest_transition) + + +def test_rail_env_has_deadend(): + deadends = set([int(rw('0010 0000 0000 0000'), 2), + int(rw('0000 0001 0000 0000'), 2), + int(rw('0000 0000 1000 0000'), 2), + int(rw('0000 0000 0000 0100'), 2)]) + ret = RailEnvTransitions() + transitions_all = ret.transitions_all + for t in transitions_all: + expected_has_deadend = t in deadends + actual_had_deadend = ret.has_deadend(t) + assert actual_had_deadend == expected_has_deadend, \ + "{} should be deadend = {}, actual = {}".format(t, ) + + +def test_rail_env_remove_deadend(): + ret = Grid4Transitions([]) + rail_env_deadends = set([int(rw('0010 0000 0000 0000'), 2), + int(rw('0000 0001 0000 0000'), 2), + int(rw('0000 0000 1000 0000'), 2), + int(rw('0000 0000 0000 0100'), 2)]) + for t in rail_env_deadends: + expected_has_deadend = 0 + actual_had_deadend = ret.remove_deadends(t) + assert actual_had_deadend == expected_has_deadend, \ + "{} should be deadend = {}, actual = {}".format(t, ) + + assert ret.remove_deadends(int(rw('0010 0001 1000 0100'), 2)) == 0 + assert ret.remove_deadends(int(rw('0010 0001 1000 0110'), 2)) == int(rw('0000 0000 0000 0010'), 2) diff --git a/tests/test_flatland_envs_env_utils.py b/tests/test_flatland_envs_env_utils.py index 25952031c1384430421492244e096dbd29fc557e..49b619a159870c9105137ed41a8b55aa1dd19e36 100644 --- a/tests/test_flatland_envs_env_utils.py +++ b/tests/test_flatland_envs_env_utils.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from flatland.core.transitions import Grid4TransitionsEnum +from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.env_utils import position_to_coordinate, coordinate_to_position, get_direction depth_to_test = 5 diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index f34829d7f5a94230473e6394bbafc0c8c9ae78c1..16850672c1f5a479f1cf86ca3e3f6c547c4a4ca7 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -3,8 +3,8 @@ import numpy as np +from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.transition_map import GridTransitionMap, Grid4Transitions -from flatland.core.transitions import Grid4TransitionsEnum from flatland.envs.generators import rail_from_GridTransitionMap_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv diff --git a/tests/test_environments.py b/tests/test_flatland_envs_rail_env.py similarity index 95% rename from tests/test_environments.py rename to tests/test_flatland_envs_rail_env.py index 79160ce549155ea7dc8e905f015642d7b6ed5723..3a50c482176cc1352a3587edcb92704525c58c08 100644 --- a/tests/test_environments.py +++ b/tests/test_flatland_envs_rail_env.py @@ -2,9 +2,11 @@ # -*- coding: utf-8 -*- import numpy as np +from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.core.transitions import Grid4Transitions, RailEnvTransitions from flatland.envs.agent_utils import EnvAgent +from flatland.envs.agent_utils import EnvAgentStatic from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import rail_from_GridTransitionMap_generator from flatland.envs.observations import GlobalObsForRailEnv @@ -13,6 +15,15 @@ from flatland.envs.rail_env import RailEnv """Tests for `flatland` package.""" +def test_load_env(): + env = RailEnv(10, 10) + env.load_resource('env_data.tests', 'test-10x10.mpk') + + agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False) + env.add_agent_static(agent_static) + assert env.get_num_agents() == 1 + + def test_save_load(): env = RailEnv(width=10, height=10, rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0), diff --git a/tests/test_rendertools.py b/tests/test_flatland_utils_rendertools.py similarity index 96% rename from tests/test_rendertools.py rename to tests/test_flatland_utils_rendertools.py index 14edfee709ce2c3a9110a5af34293d6b5c8d01f4..ff7cbd01b2845d2c84a34187dd61852dfc135bb2 100644 --- a/tests/test_rendertools.py +++ b/tests/test_flatland_utils_rendertools.py @@ -79,7 +79,7 @@ def main(): if len(sys.argv) == 2 and sys.argv[1] == "save": test_render_env(save_new_images=True) else: - print("Run 'python test_rendertools.py save' to regenerate images") + print("Run 'python test_flatland_utils_rendertools.py save' to regenerate images") test_render_env() diff --git a/tests/test_player.py b/tests/test_player.py deleted file mode 100644 index 757fc90dc14003f936be9b91a29078646f109971..0000000000000000000000000000000000000000 --- a/tests/test_player.py +++ /dev/null @@ -1,6 +0,0 @@ -from examples.play_model import main - - -def test_main(): - main(render=True, n_steps=20, n_trials=2, sGL="PIL") - main(render=True, n_steps=20, n_trials=2, sGL="PILSVG")