Commit bb4bf54e authored by u214892's avatar u214892
Browse files

#62 increase unit test coverage

parent ab40849f
Pipeline #1188 failed with stage
in 6 minutes and 1 second
......@@ -3,7 +3,7 @@ import random
import numpy as np
from flatland.core.transition_map import GridTransitionMap
from flatland.core.transitions import RailEnvTransitions
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
......
from enum import IntEnum
import numpy as np
from flatland.core.transitions import Transitions
class Grid4TransitionsEnum(IntEnum):
NORTH = 0
EAST = 1
SOUTH = 2
WEST = 3
class Grid4Transitions(Transitions):
"""
Grid4Transitions class derived from Transitions.
Special case of `Transitions' over a 2D-grid (FlatLand).
Transitions are possible to neighboring cells on the grid if allowed.
GridTransitions keeps track of valid transitions supplied as `transitions'
list, each represented as a bitmap of 16 bits.
Whether a transition is allowed or not depends on which direction an agent
inside the cell is facing (0=North, 1=East, 2=South, 3=West) and which
direction the agent wants to move to
(North, East, South, West, relative to the cell).
Each transition (orientation, direction)
can be allowed (1) or forbidden (0).
For example, in case of no diagonal transitions on the grid, the 16 bits
of the transition bitmaps are organized in 4 blocks of 4 bits each, the
direction that the agent is facing.
E.g., the most-significant 4-bits represent the possible movements (NESW)
if the agent is facing North, etc...
agent's direction: North East South West
agent's allowed movements: [nesw] [nesw] [nesw] [nesw]
example: 1000 0000 0010 0000
In the example, the agent can move from North to South and viceversa.
"""
def __init__(self, transitions):
self.transitions = transitions
self.sDirs = "NESW"
self.lsDirs = list(self.sDirs)
# row,col delta for each direction
self.gDir2dRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
def get_type(self):
return np.uint16
def get_transitions(self, cell_transition, orientation):
"""
Get the 4 possible transitions ((N,E,S,W), 4 elements tuple
if no diagonal transitions allowed) available for an agent oriented
in direction `orientation' and inside a cell with
transitions `cell_transition'.
Parameters
----------
cell_transition : int
16 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
Returns
-------
tuple
List of the validity of transitions in the cell.
"""
bits = (cell_transition >> ((3 - orientation) * 4))
return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1)
def set_transitions(self, cell_transition, orientation, new_transitions):
"""
Set the possible transitions (e.g., (N,E,S,W), 4 elements tuple
if no diagonal transitions allowed) available for an agent
oriented in direction `orientation' and inside a cell with transitions
`cell_transition'. A new `cell_transition' is returned with
the specified bits replaced by `new_transitions'.
Parameters
----------
cell_transition : int
16 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
new_transitions : tuple
Tuple of new transitions validitiy for the cell.
Returns
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate
`orientation'.
"""
mask = (1 << ((4 - orientation) * 4)) - (1 << ((3 - orientation) * 4))
negmask = ~mask
new_transitions = \
(new_transitions[0] & 1) << 3 | \
(new_transitions[1] & 1) << 2 | \
(new_transitions[2] & 1) << 1 | \
(new_transitions[3] & 1)
cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4))
return cell_transition
def get_transition(self, cell_transition, orientation, direction):
"""
Get the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction'
relative to the current cell.
Parameters
----------
cell_transition : int
16 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
direction : int
Direction of movement whose validity is to be tested.
Returns
-------
int
Validity of the requested transition: 0/1 allowed/not allowed.
"""
return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1
def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False):
"""
Set the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction'
relative to the current cell.
Parameters
----------
cell_transition : int
16 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
direction : int
Direction of movement whose validity is to be tested.
new_transition : int
Validity of the requested transition: 0/1 allowed/not allowed.
remove_deadends -- boolean, default False
remove all deadend transitions.
Returns
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate
`orientation'.
"""
if new_transition:
cell_transition |= (1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction)))
else:
cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction)))
if remove_deadends:
cell_transition = self.remove_deadends(cell_transition)
return cell_transition
def rotate_transition(self, cell_transition, rotation=0):
"""
Clockwise-rotate a 16-bit transition bitmap by
rotation={0, 90, 180, 270} degrees.
Parameters
----------
cell_transition : int
16 bits used to encode the valid transitions for a cell.
rotation : int
Angle by which to clock-wise rotate the transition bits in
`cell_transition' by. I.e., rotation={0, 90, 180, 270} degrees.
Returns
-------
int
An updated bitmap that replaces the original transitions bits
with the equivalent bitmap after rotation.
"""
# Rotate the individual bits in each block
value = cell_transition
rotation = rotation // 90
for i in range(4):
block_tuple = self.get_transitions(value, i)
block_tuple = block_tuple[(4 - rotation):] + block_tuple[:(4 - rotation)]
value = self.set_transitions(value, i, block_tuple)
# Rotate the 4-bits blocks
value = ((value & (2 ** (rotation * 4) - 1)) << ((4 - rotation) * 4)) | (value >> (rotation * 4))
cell_transition = value
return cell_transition
def get_direction_enum(self) -> IntEnum:
return Grid4TransitionsEnum
from enum import IntEnum
import numpy as np
from flatland.core.transitions import Transitions
class Grid8TransitionsEnum(IntEnum):
NORTH = 0
NORTH_EAST = 1
EAST = 2
SOUTH_EAST = 3
SOUTH = 4
SOUTH_WEST = 5
WEST = 6
NORTH_WEST = 7
class Grid8Transitions(Transitions):
"""
Grid8Transitions class derived from Transitions.
Special case of `Transitions' over a 2D-grid (FlatLand).
Transitions are possible to neighboring cells on the grid if allowed.
GridTransitions keeps track of valid transitions supplied as `transitions'
list, each represented as a bitmap of 64 bits.
0=North, 1=North-East, etc.
"""
def __init__(self, transitions):
self.transitions = transitions
def get_type(self):
return np.uint64
def get_transitions(self, cell_transition, orientation):
"""
Get the 8 possible transitions.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
Returns
-------
tuple
List of the validity of transitions in the cell.
"""
bits = (np.uint64(cell_transition) >> np.uint64((7 - orientation) * 8))
cell_transition = (
(bits >> np.uint64(7)) & np.uint64(1),
(bits >> np.uint64(6)) & np.uint64(1),
(bits >> np.uint64(5)) & np.uint64(1),
(bits >> np.uint64(4)) & np.uint64(1),
(bits >> np.uint64(3)) & np.uint64(1),
(bits >> np.uint64(2)) & np.uint64(1),
(bits >> np.uint64(1)) & np.uint64(1),
bits & np.uint64(1))
return cell_transition
def set_transitions(self, cell_transition, orientation, new_transitions):
"""
Set the possible transitions.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
new_transitions : tuple
Tuple of new transitions validitiy for the cell.
Returns
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate
`orientation'.
"""
mask = (1 << ((8 - orientation) * 8)) - (1 << ((7 - orientation) * 8))
negmask = ~mask
new_transitions = \
(int(new_transitions[0]) & 1) << 7 | \
(int(new_transitions[1]) & 1) << 6 | \
(int(new_transitions[2]) & 1) << 5 | \
(int(new_transitions[3]) & 1) << 4 | \
(int(new_transitions[4]) & 1) << 3 | \
(int(new_transitions[5]) & 1) << 2 | \
(int(new_transitions[6]) & 1) << 1 | \
(int(new_transitions[7]) & 1)
cell_transition = (int(cell_transition) & negmask) | (new_transitions << ((7 - orientation) * 8))
return cell_transition
def get_transition(self, cell_transition, orientation, direction):
"""
Get the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction'
relative to the current cell.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
direction : int
Direction of movement whose validity is to be tested.
Returns
-------
int
Validity of the requested transition: 0/1 allowed/not allowed.
"""
return ((cell_transition >> ((8 - 1 - orientation) * 8)) >> (8 - 1 - direction)) & 1
def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False):
"""
Set the transition bit (1 value) that determines whether an agent
oriented in direction `orientation' and inside a cell with transitions
`cell_transition' can move to the cell in direction `direction'
relative to the current cell.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
orientation : int
Orientation of the agent inside the cell.
direction : int
Direction of movement whose validity is to be tested.
new_transition : int
Validity of the requested transition: 0/1 allowed/not allowed.
Returns
-------
int
An updated bitmap that replaces the original transitions validity
of `cell_transition' with `new_transitions', for the appropriate
`orientation'.
"""
if new_transition:
cell_transition |= (1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction)))
else:
cell_transition &= ~(1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction)))
return cell_transition
def rotate_transition(self, cell_transition, rotation=0):
"""
Clockwise-rotate a 64-bit transition bitmap by
rotation={0, 45, 90, 135, 180, 225, 270, 315} degrees.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
rotation : int
Angle by which to clock-wise rotate the transition bits in
`cell_transition' by. I.e., rotation={0, 45, 90, 135, 180,
225, 270, 315} degrees.
Returns
-------
int
An updated bitmap that replaces the original transitions bits
with the equivalent bitmap after rotation.
"""
# TODO: WARNING: this part of the function has never been tested!
# Rotate the individual bits in each block
value = cell_transition
rotation = rotation // 45
for i in range(8):
block_tuple = self.get_transitions(value, i)
block_tuple = block_tuple[rotation:] + block_tuple[:rotation]
value = self.set_transitions(value, i, block_tuple)
# Rotate the 8bits blocks
value = ((value & (2 ** (rotation * 8) - 1)) << ((8 - rotation) * 8)) | (value >> (rotation * 8))
cell_transition = value
return cell_transition
def get_direction_enum(self) -> IntEnum:
return Grid8TransitionsEnum
from flatland.core.grid.grid4 import Grid4Transitions
class RailEnvTransitions(Grid4Transitions):
"""
Special case of `GridTransitions' over a 2D-grid, with a pre-defined set
of transitions mimicking the types of real Swiss rail connections.
--------------------------------------------------------------------------
As no diagonal transitions are allowed in the RailEnv environment, the
possible transitions for RailEnv from a cell to its neighboring ones
are represented over 16 bits.
The 16 bits are organized in 4 blocks of 4 bits each, the direction that
the agent is facing.
E.g., the most-significant 4-bits represent the possible movements (NESW)
if the agent is facing North, etc...
agent's direction: North East South West
agent's allowed movements: [nesw] [nesw] [nesw] [nesw]
example: 1000 0000 0010 0000
In the example, the agent can move from North to South and viceversa.
"""
# Contains the basic transitions;
# the set of all valid transitions is obtained by successive 90-degree rotation of one of these basic transitions.
transition_list = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
int('1000010000100001', 2), # Case 3 - diamond drossing
int('1001011000100001', 2), # Case 4 - single slip
int('1100110000110011', 2), # Case 5 - double slip
int('0101001000000010', 2), # Case 6 - symmetrical
int('0010000000000000', 2), # Case 7 - dead end
int('0100000000000010', 2), # Case 1b (8) - simple turn right
int('0001001000000000', 2), # Case 1c (9) - simple turn left
int('1100000000100010', 2)] # Case 2b (10) - simple switch mirrored
def __init__(self):
super(RailEnvTransitions, self).__init__(
transitions=self.transition_list
)
# These bits represent all the possible dead ends
self.maskDeadEnds = 0b0010000110000100
# create this to make validation faster
self.transitions_all = set()
for index, trans in enumerate(self.transitions):
self.transitions_all.add(trans)
if index in (2, 4, 6, 7, 8, 9, 10):
for _ in range(3):
trans = self.rotate_transition(trans, rotation=90)
self.transitions_all.add(trans)
elif index in (1, 5):
trans = self.rotate_transition(trans, rotation=90)
self.transitions_all.add(trans)
def print(self, cell_transition):
print(" NESW")
print("N", format(cell_transition >> (3 * 4) & 0xF, '04b'))
print("E", format(cell_transition >> (2 * 4) & 0xF, '04b'))
print("S", format(cell_transition >> (1 * 4) & 0xF, '04b'))
print("W", format(cell_transition >> (0 * 4) & 0xF, '04b'))
def repr(self, cell_transition, version=0):
"""
Provide a string representation of the cell transitions.
This class doesn't represent an individual cell,
but a way of interpreting the contents of a cell.
So using the ad hoc name repr rather than __repr__.
"""
# binary format string without leading 0b
sbinTrans = format(cell_transition, "#018b")[2:]
if version == 0:
sRepr = " ".join([
"{}:{}".format(sDir, sbinTrans[i:(i + 4)])
for i, sDir in
zip(
range(0, len(sbinTrans), 4),
self.lsDirs)]) # NESW
return sRepr
if version == 1:
lsRepr = []
for iDirIn in range(0, 4):
sDirTrans = sbinTrans[(iDirIn * 4):(iDirIn * 4 + 4)]
if sDirTrans == "0000":
continue
sDirsOut = [
self.lsDirs[iDirOut]
for iDirOut in range(0, 4)
if sDirTrans[iDirOut] == "1"]
lsRepr.append(self.lsDirs[iDirIn] + ":" + "".join(sDirsOut))
return ", ".join(lsRepr)
def is_valid(self, cell_transition):
"""
Checks if a cell transition is a valid cell setup.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
Returns
-------
Boolean
True or False
"""
return cell_transition in self.transitions_all
def has_deadend(self, cell_transition):
if cell_transition & self.maskDeadEnds > 0:
return True
else:
return False
def remove_deadends(self, cell_transition):
cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff
return cell_transition
......@@ -6,7 +6,7 @@ import numpy as np
from importlib_resources import path
from numpy import array
from .transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions
from flatland.core.grid.grid4 import Grid4Transitions
class TransitionMap:
......@@ -73,7 +73,7 @@ class TransitionMap:
Returns
-------
int or float (depending on derived class)
int or float (depending on Transitions used)
Validity of the requested transition (e.g.,
0/1 allowed/not allowed, a probability in [0,1], etc...)
......@@ -95,7 +95,7 @@ class TransitionMap:
Index of the transition to probe, as index in the tuple returned by
get_transitions(). e.g., the NESW direction of movement, for agents
on a grid.
new_transition : int or float (depending on derived class)
new_transition : int or float (depending on Transitions used)
Validity of the requested transition (e.g.,
0/1 allowed/not allowed, a probability in [0,1], etc...)
......@@ -130,10 +130,7 @@ class GridTransitionMap(TransitionMap):
self.height = height
self.transitions = transitions
if isinstance(self.transitions, Grid4Transitions) or isinstance(self.transitions, RailEnvTransitions):
self.grid = np.ndarray((height, width), dtype=np.uint16)
elif isinstance(self.transitions, Grid8Transitions):
self.grid = np.ndarray((height, width), dtype=np.uint64)
self.grid = np.zeros((height, width), dtype=self.transitions.get_type())
def get_transitions(self, cell_id):
"""
......@@ -156,14 +153,12 @@ class GridTransitionMap(TransitionMap):
List of the validity of transitions in the cell.
"""
assert len(cell_id) in (2, 3), \
'GridTransitionMap.get_transitions() ERROR: cell_id tuple must have length 2 or 3.'
if len(cell_id) == 3:
return self.transitions.get_transitions(self.grid[cell_id[0]][cell_id[1]], cell_id[2])
elif len(cell_id) == 2:
return self.grid[cell_id[0]][cell_id[1]]
else:
print('GridTransitionMap.get_transitions() ERROR: \
wrong cell_id tuple.')
return ()
def set_transitions(self, cell_id, new_transitions):
"""
......@@ -182,15 +177,14 @@ class GridTransitionMap(TransitionMap):
Tuple of new transitions validitiy for the cell.
"""
assert len(cell_id) in (2, 3), \