From bb4bf54efc890d02ae8e29c397590dcd3eb58d7b Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Thu, 20 Jun 2019 14:25:36 +0200 Subject: [PATCH] #62 increase unit test coverage --- examples/custom_railmap_example.py | 2 +- flatland/core/grid/__init__.py | 0 flatland/core/grid/grid4.py | 212 +++++++++ flatland/core/grid/grid8.py | 203 ++++++++ flatland/core/grid/rail_env_grid.py | 124 +++++ flatland/core/transition_map.py | 39 +- flatland/core/transitions.py | 527 +-------------------- flatland/envs/env_utils.py | 2 +- flatland/envs/generators.py | 2 +- flatland/envs/observations.py | 2 +- flatland/utils/graphics_pil.py | 2 +- flatland/utils/svg.py | 2 +- tests/test_flatland_core_transition_map.py | 16 +- tests/test_flatland_core_transitions.py | 5 +- tests/test_flatland_envs_env_utils.py | 2 +- tests/test_flatland_envs_predictions.py | 2 +- tests/test_flatland_envs_rail_env.py | 3 +- 17 files changed, 582 insertions(+), 563 deletions(-) create mode 100644 flatland/core/grid/__init__.py create mode 100644 flatland/core/grid/grid4.py create mode 100644 flatland/core/grid/grid8.py create mode 100644 flatland/core/grid/rail_env_grid.py diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py index 16ec480f..9ccef3fd 100644 --- a/examples/custom_railmap_example.py +++ b/examples/custom_railmap_example.py @@ -3,7 +3,7 @@ import random import numpy as np from flatland.core.transition_map import GridTransitionMap -from flatland.core.transitions import RailEnvTransitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool diff --git a/flatland/core/grid/__init__.py b/flatland/core/grid/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/flatland/core/grid/grid4.py b/flatland/core/grid/grid4.py new file mode 100644 index 00000000..3febb521 --- /dev/null +++ b/flatland/core/grid/grid4.py @@ -0,0 +1,212 @@ +from enum import IntEnum + +import numpy as np + +from flatland.core.transitions import Transitions + + +class Grid4TransitionsEnum(IntEnum): + NORTH = 0 + EAST = 1 + SOUTH = 2 + WEST = 3 + + +class Grid4Transitions(Transitions): + """ + Grid4Transitions class derived from Transitions. + + Special case of `Transitions' over a 2D-grid (FlatLand). + Transitions are possible to neighboring cells on the grid if allowed. + GridTransitions keeps track of valid transitions supplied as `transitions' + list, each represented as a bitmap of 16 bits. + + Whether a transition is allowed or not depends on which direction an agent + inside the cell is facing (0=North, 1=East, 2=South, 3=West) and which + direction the agent wants to move to + (North, East, South, West, relative to the cell). + Each transition (orientation, direction) + can be allowed (1) or forbidden (0). + + For example, in case of no diagonal transitions on the grid, the 16 bits + of the transition bitmaps are organized in 4 blocks of 4 bits each, the + direction that the agent is facing. + E.g., the most-significant 4-bits represent the possible movements (NESW) + if the agent is facing North, etc... + + agent's direction: North East South West + agent's allowed movements: [nesw] [nesw] [nesw] [nesw] + example: 1000 0000 0010 0000 + + In the example, the agent can move from North to South and viceversa. + """ + + def __init__(self, transitions): + self.transitions = transitions + self.sDirs = "NESW" + self.lsDirs = list(self.sDirs) + + # row,col delta for each direction + self.gDir2dRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) + + def get_type(self): + return np.uint16 + + def get_transitions(self, cell_transition, orientation): + """ + Get the 4 possible transitions ((N,E,S,W), 4 elements tuple + if no diagonal transitions allowed) available for an agent oriented + in direction `orientation' and inside a cell with + transitions `cell_transition'. + + Parameters + ---------- + cell_transition : int + 16 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + + Returns + ------- + tuple + List of the validity of transitions in the cell. + + """ + bits = (cell_transition >> ((3 - orientation) * 4)) + return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1) + + def set_transitions(self, cell_transition, orientation, new_transitions): + """ + Set the possible transitions (e.g., (N,E,S,W), 4 elements tuple + if no diagonal transitions allowed) available for an agent + oriented in direction `orientation' and inside a cell with transitions + `cell_transition'. A new `cell_transition' is returned with + the specified bits replaced by `new_transitions'. + + Parameters + ---------- + cell_transition : int + 16 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + new_transitions : tuple + Tuple of new transitions validitiy for the cell. + + Returns + ------- + int + An updated bitmap that replaces the original transitions validity + of `cell_transition' with `new_transitions', for the appropriate + `orientation'. + + """ + mask = (1 << ((4 - orientation) * 4)) - (1 << ((3 - orientation) * 4)) + negmask = ~mask + + new_transitions = \ + (new_transitions[0] & 1) << 3 | \ + (new_transitions[1] & 1) << 2 | \ + (new_transitions[2] & 1) << 1 | \ + (new_transitions[3] & 1) + + cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4)) + + return cell_transition + + def get_transition(self, cell_transition, orientation, direction): + """ + Get the transition bit (1 value) that determines whether an agent + oriented in direction `orientation' and inside a cell with transitions + `cell_transition' can move to the cell in direction `direction' + relative to the current cell. + + Parameters + ---------- + cell_transition : int + 16 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + direction : int + Direction of movement whose validity is to be tested. + + Returns + ------- + int + Validity of the requested transition: 0/1 allowed/not allowed. + + """ + return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1 + + def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False): + """ + Set the transition bit (1 value) that determines whether an agent + oriented in direction `orientation' and inside a cell with transitions + `cell_transition' can move to the cell in direction `direction' + relative to the current cell. + + Parameters + ---------- + cell_transition : int + 16 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + direction : int + Direction of movement whose validity is to be tested. + new_transition : int + Validity of the requested transition: 0/1 allowed/not allowed. + remove_deadends -- boolean, default False + remove all deadend transitions. + Returns + ------- + int + An updated bitmap that replaces the original transitions validity + of `cell_transition' with `new_transitions', for the appropriate + `orientation'. + + """ + if new_transition: + cell_transition |= (1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction))) + else: + cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction))) + + if remove_deadends: + cell_transition = self.remove_deadends(cell_transition) + + return cell_transition + + def rotate_transition(self, cell_transition, rotation=0): + """ + Clockwise-rotate a 16-bit transition bitmap by + rotation={0, 90, 180, 270} degrees. + + Parameters + ---------- + cell_transition : int + 16 bits used to encode the valid transitions for a cell. + rotation : int + Angle by which to clock-wise rotate the transition bits in + `cell_transition' by. I.e., rotation={0, 90, 180, 270} degrees. + + Returns + ------- + int + An updated bitmap that replaces the original transitions bits + with the equivalent bitmap after rotation. + + """ + # Rotate the individual bits in each block + value = cell_transition + rotation = rotation // 90 + for i in range(4): + block_tuple = self.get_transitions(value, i) + block_tuple = block_tuple[(4 - rotation):] + block_tuple[:(4 - rotation)] + value = self.set_transitions(value, i, block_tuple) + + # Rotate the 4-bits blocks + value = ((value & (2 ** (rotation * 4) - 1)) << ((4 - rotation) * 4)) | (value >> (rotation * 4)) + + cell_transition = value + return cell_transition + + def get_direction_enum(self) -> IntEnum: + return Grid4TransitionsEnum diff --git a/flatland/core/grid/grid8.py b/flatland/core/grid/grid8.py new file mode 100644 index 00000000..2ba379a5 --- /dev/null +++ b/flatland/core/grid/grid8.py @@ -0,0 +1,203 @@ +from enum import IntEnum + +import numpy as np + +from flatland.core.transitions import Transitions + + +class Grid8TransitionsEnum(IntEnum): + NORTH = 0 + NORTH_EAST = 1 + EAST = 2 + SOUTH_EAST = 3 + SOUTH = 4 + SOUTH_WEST = 5 + WEST = 6 + NORTH_WEST = 7 + + +class Grid8Transitions(Transitions): + """ + Grid8Transitions class derived from Transitions. + + Special case of `Transitions' over a 2D-grid (FlatLand). + Transitions are possible to neighboring cells on the grid if allowed. + GridTransitions keeps track of valid transitions supplied as `transitions' + list, each represented as a bitmap of 64 bits. + + 0=North, 1=North-East, etc. + + """ + + def __init__(self, transitions): + self.transitions = transitions + + def get_type(self): + return np.uint64 + + def get_transitions(self, cell_transition, orientation): + """ + Get the 8 possible transitions. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + + Returns + ------- + tuple + List of the validity of transitions in the cell. + + """ + bits = (np.uint64(cell_transition) >> np.uint64((7 - orientation) * 8)) + cell_transition = ( + (bits >> np.uint64(7)) & np.uint64(1), + (bits >> np.uint64(6)) & np.uint64(1), + (bits >> np.uint64(5)) & np.uint64(1), + (bits >> np.uint64(4)) & np.uint64(1), + (bits >> np.uint64(3)) & np.uint64(1), + (bits >> np.uint64(2)) & np.uint64(1), + (bits >> np.uint64(1)) & np.uint64(1), + bits & np.uint64(1)) + + return cell_transition + + def set_transitions(self, cell_transition, orientation, new_transitions): + """ + Set the possible transitions. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + new_transitions : tuple + Tuple of new transitions validitiy for the cell. + + Returns + ------- + int + An updated bitmap that replaces the original transitions validity + of `cell_transition' with `new_transitions', for the appropriate + `orientation'. + + """ + mask = (1 << ((8 - orientation) * 8)) - (1 << ((7 - orientation) * 8)) + negmask = ~mask + + new_transitions = \ + (int(new_transitions[0]) & 1) << 7 | \ + (int(new_transitions[1]) & 1) << 6 | \ + (int(new_transitions[2]) & 1) << 5 | \ + (int(new_transitions[3]) & 1) << 4 | \ + (int(new_transitions[4]) & 1) << 3 | \ + (int(new_transitions[5]) & 1) << 2 | \ + (int(new_transitions[6]) & 1) << 1 | \ + (int(new_transitions[7]) & 1) + + cell_transition = (int(cell_transition) & negmask) | (new_transitions << ((7 - orientation) * 8)) + + return cell_transition + + def get_transition(self, cell_transition, orientation, direction): + """ + Get the transition bit (1 value) that determines whether an agent + oriented in direction `orientation' and inside a cell with transitions + `cell_transition' can move to the cell in direction `direction' + relative to the current cell. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + direction : int + Direction of movement whose validity is to be tested. + + Returns + ------- + int + Validity of the requested transition: 0/1 allowed/not allowed. + + """ + return ((cell_transition >> ((8 - 1 - orientation) * 8)) >> (8 - 1 - direction)) & 1 + + def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False): + + """ + Set the transition bit (1 value) that determines whether an agent + oriented in direction `orientation' and inside a cell with transitions + `cell_transition' can move to the cell in direction `direction' + relative to the current cell. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + direction : int + Direction of movement whose validity is to be tested. + new_transition : int + Validity of the requested transition: 0/1 allowed/not allowed. + + Returns + ------- + int + An updated bitmap that replaces the original transitions validity + of `cell_transition' with `new_transitions', for the appropriate + `orientation'. + + """ + if new_transition: + cell_transition |= (1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction))) + else: + cell_transition &= ~(1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction))) + + return cell_transition + + def rotate_transition(self, cell_transition, rotation=0): + """ + Clockwise-rotate a 64-bit transition bitmap by + rotation={0, 45, 90, 135, 180, 225, 270, 315} degrees. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + rotation : int + Angle by which to clock-wise rotate the transition bits in + `cell_transition' by. I.e., rotation={0, 45, 90, 135, 180, + 225, 270, 315} degrees. + + Returns + ------- + int + An updated bitmap that replaces the original transitions bits + with the equivalent bitmap after rotation. + + """ + # TODO: WARNING: this part of the function has never been tested! + + # Rotate the individual bits in each block + value = cell_transition + rotation = rotation // 45 + for i in range(8): + block_tuple = self.get_transitions(value, i) + block_tuple = block_tuple[rotation:] + block_tuple[:rotation] + value = self.set_transitions(value, i, block_tuple) + + # Rotate the 8bits blocks + value = ((value & (2 ** (rotation * 8) - 1)) << ((8 - rotation) * 8)) | (value >> (rotation * 8)) + + cell_transition = value + + return cell_transition + + def get_direction_enum(self) -> IntEnum: + return Grid8TransitionsEnum diff --git a/flatland/core/grid/rail_env_grid.py b/flatland/core/grid/rail_env_grid.py new file mode 100644 index 00000000..efb5ea15 --- /dev/null +++ b/flatland/core/grid/rail_env_grid.py @@ -0,0 +1,124 @@ +from flatland.core.grid.grid4 import Grid4Transitions + + +class RailEnvTransitions(Grid4Transitions): + """ + Special case of `GridTransitions' over a 2D-grid, with a pre-defined set + of transitions mimicking the types of real Swiss rail connections. + + -------------------------------------------------------------------------- + + As no diagonal transitions are allowed in the RailEnv environment, the + possible transitions for RailEnv from a cell to its neighboring ones + are represented over 16 bits. + + The 16 bits are organized in 4 blocks of 4 bits each, the direction that + the agent is facing. + E.g., the most-significant 4-bits represent the possible movements (NESW) + if the agent is facing North, etc... + + agent's direction: North East South West + agent's allowed movements: [nesw] [nesw] [nesw] [nesw] + example: 1000 0000 0010 0000 + + In the example, the agent can move from North to South and viceversa. + """ + + # Contains the basic transitions; + # the set of all valid transitions is obtained by successive 90-degree rotation of one of these basic transitions. + transition_list = [int('0000000000000000', 2), # empty cell - Case 0 + int('1000000000100000', 2), # Case 1 - straight + int('1001001000100000', 2), # Case 2 - simple switch + int('1000010000100001', 2), # Case 3 - diamond drossing + int('1001011000100001', 2), # Case 4 - single slip + int('1100110000110011', 2), # Case 5 - double slip + int('0101001000000010', 2), # Case 6 - symmetrical + int('0010000000000000', 2), # Case 7 - dead end + int('0100000000000010', 2), # Case 1b (8) - simple turn right + int('0001001000000000', 2), # Case 1c (9) - simple turn left + int('1100000000100010', 2)] # Case 2b (10) - simple switch mirrored + + def __init__(self): + super(RailEnvTransitions, self).__init__( + transitions=self.transition_list + ) + + # These bits represent all the possible dead ends + self.maskDeadEnds = 0b0010000110000100 + + # create this to make validation faster + self.transitions_all = set() + for index, trans in enumerate(self.transitions): + self.transitions_all.add(trans) + if index in (2, 4, 6, 7, 8, 9, 10): + for _ in range(3): + trans = self.rotate_transition(trans, rotation=90) + self.transitions_all.add(trans) + elif index in (1, 5): + trans = self.rotate_transition(trans, rotation=90) + self.transitions_all.add(trans) + + def print(self, cell_transition): + print(" NESW") + print("N", format(cell_transition >> (3 * 4) & 0xF, '04b')) + print("E", format(cell_transition >> (2 * 4) & 0xF, '04b')) + print("S", format(cell_transition >> (1 * 4) & 0xF, '04b')) + print("W", format(cell_transition >> (0 * 4) & 0xF, '04b')) + + def repr(self, cell_transition, version=0): + """ + Provide a string representation of the cell transitions. + This class doesn't represent an individual cell, + but a way of interpreting the contents of a cell. + So using the ad hoc name repr rather than __repr__. + """ + # binary format string without leading 0b + sbinTrans = format(cell_transition, "#018b")[2:] + if version == 0: + sRepr = " ".join([ + "{}:{}".format(sDir, sbinTrans[i:(i + 4)]) + for i, sDir in + zip( + range(0, len(sbinTrans), 4), + self.lsDirs)]) # NESW + return sRepr + + if version == 1: + lsRepr = [] + for iDirIn in range(0, 4): + sDirTrans = sbinTrans[(iDirIn * 4):(iDirIn * 4 + 4)] + if sDirTrans == "0000": + continue + sDirsOut = [ + self.lsDirs[iDirOut] + for iDirOut in range(0, 4) + if sDirTrans[iDirOut] == "1"] + lsRepr.append(self.lsDirs[iDirIn] + ":" + "".join(sDirsOut)) + + return ", ".join(lsRepr) + + def is_valid(self, cell_transition): + """ + Checks if a cell transition is a valid cell setup. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + + Returns + ------- + Boolean + True or False + """ + return cell_transition in self.transitions_all + + def has_deadend(self, cell_transition): + if cell_transition & self.maskDeadEnds > 0: + return True + else: + return False + + def remove_deadends(self, cell_transition): + cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff + return cell_transition diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 6c9bde42..6c0b92a7 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -6,7 +6,7 @@ import numpy as np from importlib_resources import path from numpy import array -from .transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions +from flatland.core.grid.grid4 import Grid4Transitions class TransitionMap: @@ -73,7 +73,7 @@ class TransitionMap: Returns ------- - int or float (depending on derived class) + int or float (depending on Transitions used) Validity of the requested transition (e.g., 0/1 allowed/not allowed, a probability in [0,1], etc...) @@ -95,7 +95,7 @@ class TransitionMap: Index of the transition to probe, as index in the tuple returned by get_transitions(). e.g., the NESW direction of movement, for agents on a grid. - new_transition : int or float (depending on derived class) + new_transition : int or float (depending on Transitions used) Validity of the requested transition (e.g., 0/1 allowed/not allowed, a probability in [0,1], etc...) @@ -130,10 +130,7 @@ class GridTransitionMap(TransitionMap): self.height = height self.transitions = transitions - if isinstance(self.transitions, Grid4Transitions) or isinstance(self.transitions, RailEnvTransitions): - self.grid = np.ndarray((height, width), dtype=np.uint16) - elif isinstance(self.transitions, Grid8Transitions): - self.grid = np.ndarray((height, width), dtype=np.uint64) + self.grid = np.zeros((height, width), dtype=self.transitions.get_type()) def get_transitions(self, cell_id): """ @@ -156,14 +153,12 @@ class GridTransitionMap(TransitionMap): List of the validity of transitions in the cell. """ + assert len(cell_id) in (2, 3), \ + 'GridTransitionMap.get_transitions() ERROR: cell_id tuple must have length 2 or 3.' if len(cell_id) == 3: return self.transitions.get_transitions(self.grid[cell_id[0]][cell_id[1]], cell_id[2]) elif len(cell_id) == 2: return self.grid[cell_id[0]][cell_id[1]] - else: - print('GridTransitionMap.get_transitions() ERROR: \ - wrong cell_id tuple.') - return () def set_transitions(self, cell_id, new_transitions): """ @@ -182,15 +177,14 @@ class GridTransitionMap(TransitionMap): Tuple of new transitions validitiy for the cell. """ + assert len(cell_id) in (2, 3), \ + 'GridTransitionMap.set_transitions() ERROR: cell_id tuple must have length 2 or 3.' if len(cell_id) == 3: self.grid[cell_id[0]][cell_id[1]] = self.transitions.set_transitions(self.grid[cell_id[0]][cell_id[1]], cell_id[2], new_transitions) elif len(cell_id) == 2: self.grid[cell_id[0]][cell_id[1]] = new_transitions - else: - print('GridTransitionMap.get_transitions() ERROR: \ - wrong cell_id tuple.') def get_transition(self, cell_id, transition_index): """ @@ -210,15 +204,14 @@ class GridTransitionMap(TransitionMap): Returns ------- - int or float (depending on derived class) + int or float (depending on Transitions used in the ) Validity of the requested transition (e.g., 0/1 allowed/not allowed, a probability in [0,1], etc...) """ - if len(cell_id) != 3: - print('GridTransitionMap.get_transition() ERROR: \ - wrong cell_id tuple.') - return () + + assert len(cell_id) == 3, \ + 'GridTransitionMap.get_transition() ERROR: cell_id tuple must have length 2 or 3.' return self.transitions.get_transition(self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index) def set_transition(self, cell_id, transition_index, new_transition, remove_deadends=False): @@ -236,15 +229,13 @@ class GridTransitionMap(TransitionMap): Index of the transition to probe, as index in the tuple returned by get_transitions(). e.g., the NESW direction of movement, for agents on a grid. - new_transition : int or float (depending on derived class) + new_transition : int or float (depending on Transitions used in the map.) Validity of the requested transition (e.g., 0/1 allowed/not allowed, a probability in [0,1], etc...) """ - if len(cell_id) != 3: - print('GridTransitionMap.set_transition() ERROR: \ - wrong cell_id tuple.') - return + assert len(cell_id) == 3, \ + 'GridTransitionMap.set_transition() ERROR: cell_id tuple must have length 3.' self.grid[cell_id[0]][cell_id[1]] = self.transitions.set_transition( self.grid[cell_id[0]][cell_id[1]], cell_id[2], diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index 1c3c924a..29b57c40 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -5,8 +5,6 @@ possible transitions over a 2D grid. """ from enum import IntEnum -import numpy as np - class Transitions: """ @@ -17,6 +15,9 @@ class Transitions: `orientation' and moving into direction `direction') """ + def get_type(self): + raise NotImplementedError() + def get_transitions(self, cell_transition, orientation): """ Return a tuple of transitions available in a cell specified by @@ -132,525 +133,3 @@ class Transitions: def get_direction_enum(self) -> IntEnum: raise NotImplementedError() - - -class Grid4TransitionsEnum(IntEnum): - NORTH = 0 - EAST = 1 - SOUTH = 2 - WEST = 3 - - -class Grid4Transitions(Transitions): - """ - Grid4Transitions class derived from Transitions. - - Special case of `Transitions' over a 2D-grid (FlatLand). - Transitions are possible to neighboring cells on the grid if allowed. - GridTransitions keeps track of valid transitions supplied as `transitions' - list, each represented as a bitmap of 16 bits. - - Whether a transition is allowed or not depends on which direction an agent - inside the cell is facing (0=North, 1=East, 2=South, 3=West) and which - direction the agent wants to move to - (North, East, South, West, relative to the cell). - Each transition (orientation, direction) - can be allowed (1) or forbidden (0). - - For example, in case of no diagonal transitions on the grid, the 16 bits - of the transition bitmaps are organized in 4 blocks of 4 bits each, the - direction that the agent is facing. - E.g., the most-significant 4-bits represent the possible movements (NESW) - if the agent is facing North, etc... - - agent's direction: North East South West - agent's allowed movements: [nesw] [nesw] [nesw] [nesw] - example: 1000 0000 0010 0000 - - In the example, the agent can move from North to South and viceversa. - """ - - def __init__(self, transitions): - self.transitions = transitions - self.sDirs = "NESW" - self.lsDirs = list(self.sDirs) - - # row,col delta for each direction - self.gDir2dRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) - - def get_transitions(self, cell_transition, orientation): - """ - Get the 4 possible transitions ((N,E,S,W), 4 elements tuple - if no diagonal transitions allowed) available for an agent oriented - in direction `orientation' and inside a cell with - transitions `cell_transition'. - - Parameters - ---------- - cell_transition : int - 16 bits used to encode the valid transitions for a cell. - orientation : int - Orientation of the agent inside the cell. - - Returns - ------- - tuple - List of the validity of transitions in the cell. - - """ - bits = (cell_transition >> ((3 - orientation) * 4)) - return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1) - - def set_transitions(self, cell_transition, orientation, new_transitions): - """ - Set the possible transitions (e.g., (N,E,S,W), 4 elements tuple - if no diagonal transitions allowed) available for an agent - oriented in direction `orientation' and inside a cell with transitions - `cell_transition'. A new `cell_transition' is returned with - the specified bits replaced by `new_transitions'. - - Parameters - ---------- - cell_transition : int - 16 bits used to encode the valid transitions for a cell. - orientation : int - Orientation of the agent inside the cell. - new_transitions : tuple - Tuple of new transitions validitiy for the cell. - - Returns - ------- - int - An updated bitmap that replaces the original transitions validity - of `cell_transition' with `new_transitions', for the appropriate - `orientation'. - - """ - mask = (1 << ((4 - orientation) * 4)) - (1 << ((3 - orientation) * 4)) - negmask = ~mask - - new_transitions = \ - (new_transitions[0] & 1) << 3 | \ - (new_transitions[1] & 1) << 2 | \ - (new_transitions[2] & 1) << 1 | \ - (new_transitions[3] & 1) - - cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4)) - - return cell_transition - - def get_transition(self, cell_transition, orientation, direction): - """ - Get the transition bit (1 value) that determines whether an agent - oriented in direction `orientation' and inside a cell with transitions - `cell_transition' can move to the cell in direction `direction' - relative to the current cell. - - Parameters - ---------- - cell_transition : int - 16 bits used to encode the valid transitions for a cell. - orientation : int - Orientation of the agent inside the cell. - direction : int - Direction of movement whose validity is to be tested. - - Returns - ------- - int - Validity of the requested transition: 0/1 allowed/not allowed. - - """ - return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1 - - def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False): - """ - Set the transition bit (1 value) that determines whether an agent - oriented in direction `orientation' and inside a cell with transitions - `cell_transition' can move to the cell in direction `direction' - relative to the current cell. - - Parameters - ---------- - cell_transition : int - 16 bits used to encode the valid transitions for a cell. - orientation : int - Orientation of the agent inside the cell. - direction : int - Direction of movement whose validity is to be tested. - new_transition : int - Validity of the requested transition: 0/1 allowed/not allowed. - remove_deadends -- boolean, default False - remove all deadend transitions. - Returns - ------- - int - An updated bitmap that replaces the original transitions validity - of `cell_transition' with `new_transitions', for the appropriate - `orientation'. - - """ - if new_transition: - cell_transition |= (1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction))) - else: - cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction))) - - if remove_deadends: - cell_transition = self.remove_deadends(cell_transition) - - return cell_transition - - def rotate_transition(self, cell_transition, rotation=0): - """ - Clockwise-rotate a 16-bit transition bitmap by - rotation={0, 90, 180, 270} degrees. - - Parameters - ---------- - cell_transition : int - 16 bits used to encode the valid transitions for a cell. - rotation : int - Angle by which to clock-wise rotate the transition bits in - `cell_transition' by. I.e., rotation={0, 90, 180, 270} degrees. - - Returns - ------- - int - An updated bitmap that replaces the original transitions bits - with the equivalent bitmap after rotation. - - """ - # Rotate the individual bits in each block - value = cell_transition - rotation = rotation // 90 - for i in range(4): - block_tuple = self.get_transitions(value, i) - block_tuple = block_tuple[(4 - rotation):] + block_tuple[:(4 - rotation)] - value = self.set_transitions(value, i, block_tuple) - - # Rotate the 4-bits blocks - value = ((value & (2 ** (rotation * 4) - 1)) << ((4 - rotation) * 4)) | (value >> (rotation * 4)) - - cell_transition = value - return cell_transition - - def get_direction_enum(self) -> IntEnum: - return Grid4TransitionsEnum - - -class Grid8TransitionsEnum(IntEnum): - NORTH = 0 - NORTH_EAST = 1 - EAST = 2 - SOUTH_EAST = 3 - SOUTH = 4 - SOUTH_WEST = 5 - WEST = 6 - NORTH_WEST = 7 - - -class Grid8Transitions(Transitions): - """ - Grid8Transitions class derived from Transitions. - - Special case of `Transitions' over a 2D-grid (FlatLand). - Transitions are possible to neighboring cells on the grid if allowed. - GridTransitions keeps track of valid transitions supplied as `transitions' - list, each represented as a bitmap of 64 bits. - - 0=North, 1=North-East, etc. - - """ - - def __init__(self, transitions): - self.transitions = transitions - - def get_transitions(self, cell_transition, orientation): - """ - Get the 8 possible transitions. - - Parameters - ---------- - cell_transition : int - 64 bits used to encode the valid transitions for a cell. - orientation : int - Orientation of the agent inside the cell. - - Returns - ------- - tuple - List of the validity of transitions in the cell. - - """ - bits = (cell_transition >> ((7 - orientation) * 8)) - cell_transition = ( - (bits >> 7) & 1, - (bits >> 6) & 1, - (bits >> 5) & 1, - (bits >> 4) & 1, - (bits >> 3) & 1, - (bits >> 2) & 1, - (bits >> 1) & 1, - (bits) & 1) - - return cell_transition - - def set_transitions(self, cell_transition, orientation, new_transitions): - """ - Set the possible transitions. - - Parameters - ---------- - cell_transition : int - 64 bits used to encode the valid transitions for a cell. - orientation : int - Orientation of the agent inside the cell. - new_transitions : tuple - Tuple of new transitions validitiy for the cell. - - Returns - ------- - int - An updated bitmap that replaces the original transitions validity - of `cell_transition' with `new_transitions', for the appropriate - `orientation'. - - """ - mask = (1 << ((8 - orientation) * 8)) - (1 << ((7 - orientation) * 8)) - negmask = ~mask - - new_transitions = \ - (new_transitions[0] & 1) << 7 | \ - (new_transitions[1] & 1) << 6 | \ - (new_transitions[2] & 1) << 5 | \ - (new_transitions[3] & 1) << 4 | \ - (new_transitions[4] & 1) << 3 | \ - (new_transitions[5] & 1) << 2 | \ - (new_transitions[6] & 1) << 1 | \ - (new_transitions[7] & 1) - - cell_transition = (cell_transition & negmask) | (new_transitions << ((7 - orientation) * 8)) - - return cell_transition - - def get_transition(self, cell_transition, orientation, direction): - """ - Get the transition bit (1 value) that determines whether an agent - oriented in direction `orientation' and inside a cell with transitions - `cell_transition' can move to the cell in direction `direction' - relative to the current cell. - - Parameters - ---------- - cell_transition : int - 64 bits used to encode the valid transitions for a cell. - orientation : int - Orientation of the agent inside the cell. - direction : int - Direction of movement whose validity is to be tested. - - Returns - ------- - int - Validity of the requested transition: 0/1 allowed/not allowed. - - """ - return ((cell_transition >> ((8 - 1 - orientation) * 8)) >> (8 - 1 - direction)) & 1 - - def set_transition(self, cell_transition, orientation, direction, - new_transition): - """ - Set the transition bit (1 value) that determines whether an agent - oriented in direction `orientation' and inside a cell with transitions - `cell_transition' can move to the cell in direction `direction' - relative to the current cell. - - Parameters - ---------- - cell_transition : int - 64 bits used to encode the valid transitions for a cell. - orientation : int - Orientation of the agent inside the cell. - direction : int - Direction of movement whose validity is to be tested. - new_transition : int - Validity of the requested transition: 0/1 allowed/not allowed. - - Returns - ------- - int - An updated bitmap that replaces the original transitions validity - of `cell_transition' with `new_transitions', for the appropriate - `orientation'. - - """ - if new_transition: - cell_transition |= (1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction))) - else: - cell_transition &= ~(1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction))) - - return cell_transition - - def rotate_transition(self, cell_transition, rotation=0): - """ - Clockwise-rotate a 64-bit transition bitmap by - rotation={0, 45, 90, 135, 180, 225, 270, 315} degrees. - - Parameters - ---------- - cell_transition : int - 64 bits used to encode the valid transitions for a cell. - rotation : int - Angle by which to clock-wise rotate the transition bits in - `cell_transition' by. I.e., rotation={0, 45, 90, 135, 180, - 225, 270, 315} degrees. - - Returns - ------- - int - An updated bitmap that replaces the original transitions bits - with the equivalent bitmap after rotation. - - """ - # TODO: WARNING: this part of the function has never been tested! - - # Rotate the individual bits in each block - value = cell_transition - rotation = rotation // 45 - for i in range(8): - block_tuple = self.get_transitions(value, i) - block_tuple = block_tuple[rotation:] + block_tuple[:rotation] - value = self.set_transitions(value, i, block_tuple) - - # Rotate the 8bits blocks - value = ((value & (2 ** (rotation * 8) - 1)) << ((8 - rotation) * 8)) | (value >> (rotation * 8)) - - cell_transition = value - - return cell_transition - - def get_direction_enum(self) -> IntEnum: - return Grid8TransitionsEnum - - -class RailEnvTransitions(Grid4Transitions): - """ - Special case of `GridTransitions' over a 2D-grid, with a pre-defined set - of transitions mimicking the types of real Swiss rail connections. - - -------------------------------------------------------------------------- - - As no diagonal transitions are allowed in the RailEnv environment, the - possible transitions for RailEnv from a cell to its neighboring ones - are represented over 16 bits. - - The 16 bits are organized in 4 blocks of 4 bits each, the direction that - the agent is facing. - E.g., the most-significant 4-bits represent the possible movements (NESW) - if the agent is facing North, etc... - - agent's direction: North East South West - agent's allowed movements: [nesw] [nesw] [nesw] [nesw] - example: 1000 0000 0010 0000 - - In the example, the agent can move from North to South and viceversa. - """ - - # Contains the basic transitions; - # the set of all valid transitions is obtained by successive 90-degree rotation of one of these basic transitions. - transition_list = [int('0000000000000000', 2), # empty cell - Case 0 - int('1000000000100000', 2), # Case 1 - straight - int('1001001000100000', 2), # Case 2 - simple switch - int('1000010000100001', 2), # Case 3 - diamond drossing - int('1001011000100001', 2), # Case 4 - single slip - int('1100110000110011', 2), # Case 5 - double slip - int('0101001000000010', 2), # Case 6 - symmetrical - int('0010000000000000', 2), # Case 7 - dead end - int('0100000000000010', 2), # Case 1b (8) - simple turn right - int('0001001000000000', 2), # Case 1c (9) - simple turn left - int('1100000000100010', 2)] # Case 2b (10) - simple switch mirrored - - def __init__(self): - super(RailEnvTransitions, self).__init__( - transitions=self.transition_list - ) - - # These bits represent all the possible dead ends - self.maskDeadEnds = 0b0010000110000100 - - # create this to make validation faster - self.transitions_all = set() - for index, trans in enumerate(self.transitions): - self.transitions_all.add(trans) - if index in (2, 4, 6, 7, 8, 9, 10): - for _ in range(3): - trans = self.rotate_transition(trans, rotation=90) - self.transitions_all.add(trans) - elif index in (1, 5): - trans = self.rotate_transition(trans, rotation=90) - self.transitions_all.add(trans) - - def print(self, cell_transition): - print(" NESW") - print("N", format(cell_transition >> (3 * 4) & 0xF, '04b')) - print("E", format(cell_transition >> (2 * 4) & 0xF, '04b')) - print("S", format(cell_transition >> (1 * 4) & 0xF, '04b')) - print("W", format(cell_transition >> (0 * 4) & 0xF, '04b')) - - def repr(self, cell_transition, version=0): - """ - Provide a string representation of the cell transitions. - This class doesn't represent an individual cell, - but a way of interpreting the contents of a cell. - So using the ad hoc name repr rather than __repr__. - """ - # binary format string without leading 0b - sbinTrans = format(cell_transition, "#018b")[2:] - if version == 0: - sRepr = " ".join([ - "{}:{}".format(sDir, sbinTrans[i:(i + 4)]) - for i, sDir in - zip( - range(0, len(sbinTrans), 4), - self.lsDirs)]) # NESW - return sRepr - - if version == 1: - lsRepr = [] - for iDirIn in range(0, 4): - sDirTrans = sbinTrans[(iDirIn * 4):(iDirIn * 4 + 4)] - if sDirTrans == "0000": - continue - sDirsOut = [ - self.lsDirs[iDirOut] - for iDirOut in range(0, 4) - if sDirTrans[iDirOut] == "1"] - lsRepr.append(self.lsDirs[iDirIn] + ":" + "".join(sDirsOut)) - - return ", ".join(lsRepr) - - def is_valid(self, cell_transition): - """ - Checks if a cell transition is a valid cell setup. - - Parameters - ---------- - cell_transition : int - 64 bits used to encode the valid transitions for a cell. - - Returns - ------- - Boolean - True or False - """ - return cell_transition in self.transitions_all - - def has_deadend(self, cell_transition): - if cell_transition & self.maskDeadEnds > 0: - return True - else: - return False - - def remove_deadends(self, cell_transition): - cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff - return cell_transition diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py index cc4a0015..19da8946 100644 --- a/flatland/envs/env_utils.py +++ b/flatland/envs/env_utils.py @@ -7,7 +7,7 @@ a GridTransitionMap object. import numpy as np -from flatland.core.transitions import Grid4TransitionsEnum +from flatland.core.grid.grid4 import Grid4TransitionsEnum def get_direction(pos1, pos2) -> Grid4TransitionsEnum: diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 8b02d445..fa5cdccc 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -1,7 +1,7 @@ import numpy as np from flatland.core.transition_map import GridTransitionMap -from flatland.core.transitions import RailEnvTransitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.envs.env_utils import distance_on_rail, connect_rail, get_direction, mirror from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index a7f91f14..ecb06978 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -6,7 +6,7 @@ from collections import deque import numpy as np from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.transitions import Grid4TransitionsEnum +from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.env_utils import coordinate_to_position diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index c7cb8fa1..6098f434 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -28,7 +28,7 @@ enable_windows_cairo_support() from cairosvg import svg2png # noqa: E402 from screeninfo import get_monitors # noqa: E402 -from flatland.core.transitions import RailEnvTransitions # noqa: E402 +from flatland.core.grid.rail_env_grid import RailEnvTransitions # noqa: E402 class PILGL(GraphicsLayer): diff --git a/flatland/utils/svg.py b/flatland/utils/svg.py index 249b4cb5..b2e02844 100644 --- a/flatland/utils/svg.py +++ b/flatland/utils/svg.py @@ -3,7 +3,7 @@ import re import svgutils -from flatland.core.transitions import RailEnvTransitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions class SVG(object): diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py index cf1a9206..5117b12a 100644 --- a/tests/test_flatland_core_transition_map.py +++ b/tests/test_flatland_core_transition_map.py @@ -1,13 +1,21 @@ +from flatland.core.grid.grid4 import Grid4Transitions, Grid4TransitionsEnum +from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum from flatland.core.transition_map import GridTransitionMap -from flatland.core.transitions import Grid4Transitions, Grid8Transitions, Grid4TransitionsEnum def test_grid4_set_transitions(): grid4_map = GridTransitionMap(2, 2, Grid4Transitions([])) - grid4_map.set_transition((0, 0), Grid4TransitionsEnum.EAST, 1) - actual_transitions = grid4_map.get_transitions((0,0)) - assert False + assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0) + grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 1) + assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (1, 0, 0, 0) + grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 0) + assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0) def test_grid8_set_transitions(): grid8_map = GridTransitionMap(2, 2, Grid8Transitions([])) + assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0) + grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 1) + assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (1, 0, 0, 0, 0, 0, 0, 0) + grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0) + assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0) diff --git a/tests/test_flatland_core_transitions.py b/tests/test_flatland_core_transitions.py index 47def83c..14cf3073 100644 --- a/tests/test_flatland_core_transitions.py +++ b/tests/test_flatland_core_transitions.py @@ -4,7 +4,8 @@ """Tests for `flatland` package.""" import numpy as np -from flatland.core.transitions import RailEnvTransitions, Grid8Transitions +from flatland.core.grid.grid8 import Grid8Transitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.envs.env_utils import validate_new_transition @@ -194,7 +195,7 @@ def test_diagonal_transitions(): # Allowing transition from north to southwest: Facing south, going SW north_southwest_transition = \ - diagonal_trans_env.set_transitions(int('0' * 64, 2), 4, (0, 0, 0, 0, 0, 1, 0, 0)) + diagonal_trans_env.set_transitions(0, 4, (0, 0, 0, 0, 0, 1, 0, 0)) assert (diagonal_trans_env.rotate_transition( south_northeast_transition, 180) == north_southwest_transition) diff --git a/tests/test_flatland_envs_env_utils.py b/tests/test_flatland_envs_env_utils.py index 25952031..49b619a1 100644 --- a/tests/test_flatland_envs_env_utils.py +++ b/tests/test_flatland_envs_env_utils.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from flatland.core.transitions import Grid4TransitionsEnum +from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.env_utils import position_to_coordinate, coordinate_to_position, get_direction depth_to_test = 5 diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index f34829d7..16850672 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -3,8 +3,8 @@ import numpy as np +from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.transition_map import GridTransitionMap, Grid4Transitions -from flatland.core.transitions import Grid4TransitionsEnum from flatland.envs.generators import rail_from_GridTransitionMap_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index 0279ec7b..3a50c482 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -2,8 +2,9 @@ # -*- coding: utf-8 -*- import numpy as np +from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.core.transitions import Grid4Transitions, RailEnvTransitions from flatland.envs.agent_utils import EnvAgent from flatland.envs.agent_utils import EnvAgentStatic from flatland.envs.generators import complex_rail_generator -- GitLab