From bb4bf54efc890d02ae8e29c397590dcd3eb58d7b Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Thu, 20 Jun 2019 14:25:36 +0200
Subject: [PATCH] #62 increase unit test coverage

---
 examples/custom_railmap_example.py         |   2 +-
 flatland/core/grid/__init__.py             |   0
 flatland/core/grid/grid4.py                | 212 +++++++++
 flatland/core/grid/grid8.py                | 203 ++++++++
 flatland/core/grid/rail_env_grid.py        | 124 +++++
 flatland/core/transition_map.py            |  39 +-
 flatland/core/transitions.py               | 527 +--------------------
 flatland/envs/env_utils.py                 |   2 +-
 flatland/envs/generators.py                |   2 +-
 flatland/envs/observations.py              |   2 +-
 flatland/utils/graphics_pil.py             |   2 +-
 flatland/utils/svg.py                      |   2 +-
 tests/test_flatland_core_transition_map.py |  16 +-
 tests/test_flatland_core_transitions.py    |   5 +-
 tests/test_flatland_envs_env_utils.py      |   2 +-
 tests/test_flatland_envs_predictions.py    |   2 +-
 tests/test_flatland_envs_rail_env.py       |   3 +-
 17 files changed, 582 insertions(+), 563 deletions(-)
 create mode 100644 flatland/core/grid/__init__.py
 create mode 100644 flatland/core/grid/grid4.py
 create mode 100644 flatland/core/grid/grid8.py
 create mode 100644 flatland/core/grid/rail_env_grid.py

diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py
index 16ec480f..9ccef3fd 100644
--- a/examples/custom_railmap_example.py
+++ b/examples/custom_railmap_example.py
@@ -3,7 +3,7 @@ import random
 import numpy as np
 
 from flatland.core.transition_map import GridTransitionMap
-from flatland.core.transitions import RailEnvTransitions
+from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
 
diff --git a/flatland/core/grid/__init__.py b/flatland/core/grid/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/flatland/core/grid/grid4.py b/flatland/core/grid/grid4.py
new file mode 100644
index 00000000..3febb521
--- /dev/null
+++ b/flatland/core/grid/grid4.py
@@ -0,0 +1,212 @@
+from enum import IntEnum
+
+import numpy as np
+
+from flatland.core.transitions import Transitions
+
+
+class Grid4TransitionsEnum(IntEnum):
+    NORTH = 0
+    EAST = 1
+    SOUTH = 2
+    WEST = 3
+
+
+class Grid4Transitions(Transitions):
+    """
+    Grid4Transitions class derived from Transitions.
+
+    Special case of `Transitions' over a 2D-grid (FlatLand).
+    Transitions are possible to neighboring cells on the grid if allowed.
+    GridTransitions keeps track of valid transitions supplied as `transitions'
+    list, each represented as a bitmap of 16 bits.
+
+    Whether a transition is allowed or not depends on which direction an agent
+    inside the cell is facing (0=North, 1=East, 2=South, 3=West) and which
+    direction the agent wants to move to
+    (North, East, South, West, relative to the cell).
+    Each transition (orientation, direction)
+    can be allowed (1) or forbidden (0).
+
+    For example, in case of no diagonal transitions on the grid, the 16 bits
+    of the transition bitmaps are organized in 4 blocks of 4 bits each, the
+    direction that the agent is facing.
+    E.g., the most-significant 4-bits represent the possible movements (NESW)
+    if the agent is facing North, etc...
+
+    agent's direction:          North    East   South   West
+    agent's allowed movements:  [nesw]   [nesw] [nesw]  [nesw]
+    example:                     1000     0000   0010    0000
+
+    In the example, the agent can move from North to South and viceversa.
+    """
+
+    def __init__(self, transitions):
+        self.transitions = transitions
+        self.sDirs = "NESW"
+        self.lsDirs = list(self.sDirs)
+
+        # row,col delta for each direction
+        self.gDir2dRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
+
+    def get_type(self):
+        return np.uint16
+
+    def get_transitions(self, cell_transition, orientation):
+        """
+        Get the 4 possible transitions ((N,E,S,W), 4 elements tuple
+        if no diagonal transitions allowed) available for an agent oriented
+        in direction `orientation' and inside a cell with
+        transitions `cell_transition'.
+
+        Parameters
+        ----------
+        cell_transition : int
+            16 bits used to encode the valid transitions for a cell.
+        orientation : int
+            Orientation of the agent inside the cell.
+
+        Returns
+        -------
+        tuple
+            List of the validity of transitions in the cell.
+
+        """
+        bits = (cell_transition >> ((3 - orientation) * 4))
+        return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1)
+
+    def set_transitions(self, cell_transition, orientation, new_transitions):
+        """
+        Set the possible transitions (e.g., (N,E,S,W), 4 elements tuple
+        if no diagonal transitions allowed) available for an agent
+        oriented in direction `orientation' and inside a cell with transitions
+        `cell_transition'. A new `cell_transition' is returned with
+        the specified bits replaced by `new_transitions'.
+
+        Parameters
+        ----------
+        cell_transition : int
+            16 bits used to encode the valid transitions for a cell.
+        orientation : int
+            Orientation of the agent inside the cell.
+        new_transitions : tuple
+            Tuple of new transitions validitiy for the cell.
+
+        Returns
+        -------
+        int
+            An updated bitmap that replaces the original transitions validity
+            of `cell_transition' with `new_transitions', for the appropriate
+            `orientation'.
+
+        """
+        mask = (1 << ((4 - orientation) * 4)) - (1 << ((3 - orientation) * 4))
+        negmask = ~mask
+
+        new_transitions = \
+            (new_transitions[0] & 1) << 3 | \
+            (new_transitions[1] & 1) << 2 | \
+            (new_transitions[2] & 1) << 1 | \
+            (new_transitions[3] & 1)
+
+        cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4))
+
+        return cell_transition
+
+    def get_transition(self, cell_transition, orientation, direction):
+        """
+        Get the transition bit (1 value) that determines whether an agent
+        oriented in direction `orientation' and inside a cell with transitions
+        `cell_transition' can move to the cell in direction `direction'
+        relative to the current cell.
+
+        Parameters
+        ----------
+        cell_transition : int
+            16 bits used to encode the valid transitions for a cell.
+        orientation : int
+            Orientation of the agent inside the cell.
+        direction : int
+            Direction of movement whose validity is to be tested.
+
+        Returns
+        -------
+        int
+            Validity of the requested transition: 0/1 allowed/not allowed.
+
+        """
+        return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1
+
+    def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False):
+        """
+        Set the transition bit (1 value) that determines whether an agent
+        oriented in direction `orientation' and inside a cell with transitions
+        `cell_transition' can move to the cell in direction `direction'
+        relative to the current cell.
+
+        Parameters
+        ----------
+        cell_transition : int
+            16 bits used to encode the valid transitions for a cell.
+        orientation : int
+            Orientation of the agent inside the cell.
+        direction : int
+            Direction of movement whose validity is to be tested.
+        new_transition : int
+            Validity of the requested transition: 0/1 allowed/not allowed.
+        remove_deadends -- boolean, default False
+            remove all deadend transitions.
+        Returns
+        -------
+        int
+            An updated bitmap that replaces the original transitions validity
+            of `cell_transition' with `new_transitions', for the appropriate
+            `orientation'.
+
+        """
+        if new_transition:
+            cell_transition |= (1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction)))
+        else:
+            cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction)))
+
+        if remove_deadends:
+            cell_transition = self.remove_deadends(cell_transition)
+
+        return cell_transition
+
+    def rotate_transition(self, cell_transition, rotation=0):
+        """
+        Clockwise-rotate a 16-bit transition bitmap by
+        rotation={0, 90, 180, 270} degrees.
+
+        Parameters
+        ----------
+        cell_transition : int
+            16 bits used to encode the valid transitions for a cell.
+        rotation : int
+            Angle by which to clock-wise rotate the transition bits in
+            `cell_transition' by. I.e., rotation={0, 90, 180, 270} degrees.
+
+        Returns
+        -------
+        int
+            An updated bitmap that replaces the original transitions bits
+            with the equivalent bitmap after rotation.
+
+        """
+        # Rotate the individual bits in each block
+        value = cell_transition
+        rotation = rotation // 90
+        for i in range(4):
+            block_tuple = self.get_transitions(value, i)
+            block_tuple = block_tuple[(4 - rotation):] + block_tuple[:(4 - rotation)]
+            value = self.set_transitions(value, i, block_tuple)
+
+        # Rotate the 4-bits blocks
+        value = ((value & (2 ** (rotation * 4) - 1)) << ((4 - rotation) * 4)) | (value >> (rotation * 4))
+
+        cell_transition = value
+        return cell_transition
+
+    def get_direction_enum(self) -> IntEnum:
+        return Grid4TransitionsEnum
diff --git a/flatland/core/grid/grid8.py b/flatland/core/grid/grid8.py
new file mode 100644
index 00000000..2ba379a5
--- /dev/null
+++ b/flatland/core/grid/grid8.py
@@ -0,0 +1,203 @@
+from enum import IntEnum
+
+import numpy as np
+
+from flatland.core.transitions import Transitions
+
+
+class Grid8TransitionsEnum(IntEnum):
+    NORTH = 0
+    NORTH_EAST = 1
+    EAST = 2
+    SOUTH_EAST = 3
+    SOUTH = 4
+    SOUTH_WEST = 5
+    WEST = 6
+    NORTH_WEST = 7
+
+
+class Grid8Transitions(Transitions):
+    """
+    Grid8Transitions class derived from Transitions.
+
+    Special case of `Transitions' over a 2D-grid (FlatLand).
+    Transitions are possible to neighboring cells on the grid if allowed.
+    GridTransitions keeps track of valid transitions supplied as `transitions'
+    list, each represented as a bitmap of 64 bits.
+
+    0=North, 1=North-East, etc.
+
+    """
+
+    def __init__(self, transitions):
+        self.transitions = transitions
+
+    def get_type(self):
+        return np.uint64
+
+    def get_transitions(self, cell_transition, orientation):
+        """
+        Get the 8 possible transitions.
+
+        Parameters
+        ----------
+        cell_transition : int
+            64 bits used to encode the valid transitions for a cell.
+        orientation : int
+            Orientation of the agent inside the cell.
+
+        Returns
+        -------
+        tuple
+            List of the validity of transitions in the cell.
+
+        """
+        bits = (np.uint64(cell_transition) >> np.uint64((7 - orientation) * 8))
+        cell_transition = (
+            (bits >> np.uint64(7)) & np.uint64(1),
+            (bits >> np.uint64(6)) & np.uint64(1),
+            (bits >> np.uint64(5)) & np.uint64(1),
+            (bits >> np.uint64(4)) & np.uint64(1),
+            (bits >> np.uint64(3)) & np.uint64(1),
+            (bits >> np.uint64(2)) & np.uint64(1),
+            (bits >> np.uint64(1)) & np.uint64(1),
+            bits & np.uint64(1))
+
+        return cell_transition
+
+    def set_transitions(self, cell_transition, orientation, new_transitions):
+        """
+        Set the possible transitions.
+
+        Parameters
+        ----------
+        cell_transition : int
+            64 bits used to encode the valid transitions for a cell.
+        orientation : int
+            Orientation of the agent inside the cell.
+        new_transitions : tuple
+            Tuple of new transitions validitiy for the cell.
+
+        Returns
+        -------
+        int
+            An updated bitmap that replaces the original transitions validity
+            of `cell_transition' with `new_transitions', for the appropriate
+            `orientation'.
+
+        """
+        mask = (1 << ((8 - orientation) * 8)) - (1 << ((7 - orientation) * 8))
+        negmask = ~mask
+
+        new_transitions = \
+            (int(new_transitions[0]) & 1) << 7 | \
+            (int(new_transitions[1]) & 1) << 6 | \
+            (int(new_transitions[2]) & 1) << 5 | \
+            (int(new_transitions[3]) & 1) << 4 | \
+            (int(new_transitions[4]) & 1) << 3 | \
+            (int(new_transitions[5]) & 1) << 2 | \
+            (int(new_transitions[6]) & 1) << 1 | \
+            (int(new_transitions[7]) & 1)
+
+        cell_transition = (int(cell_transition) & negmask) | (new_transitions << ((7 - orientation) * 8))
+
+        return cell_transition
+
+    def get_transition(self, cell_transition, orientation, direction):
+        """
+        Get the transition bit (1 value) that determines whether an agent
+        oriented in direction `orientation' and inside a cell with transitions
+        `cell_transition' can move to the cell in direction `direction'
+        relative to the current cell.
+
+        Parameters
+        ----------
+        cell_transition : int
+            64 bits used to encode the valid transitions for a cell.
+        orientation : int
+            Orientation of the agent inside the cell.
+        direction : int
+            Direction of movement whose validity is to be tested.
+
+        Returns
+        -------
+        int
+            Validity of the requested transition: 0/1 allowed/not allowed.
+
+        """
+        return ((cell_transition >> ((8 - 1 - orientation) * 8)) >> (8 - 1 - direction)) & 1
+
+    def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False):
+
+        """
+        Set the transition bit (1 value) that determines whether an agent
+        oriented in direction `orientation' and inside a cell with transitions
+        `cell_transition' can move to the cell in direction `direction'
+        relative to the current cell.
+
+        Parameters
+        ----------
+        cell_transition : int
+            64 bits used to encode the valid transitions for a cell.
+        orientation : int
+            Orientation of the agent inside the cell.
+        direction : int
+            Direction of movement whose validity is to be tested.
+        new_transition : int
+            Validity of the requested transition: 0/1 allowed/not allowed.
+
+        Returns
+        -------
+        int
+            An updated bitmap that replaces the original transitions validity
+            of `cell_transition' with `new_transitions', for the appropriate
+            `orientation'.
+
+        """
+        if new_transition:
+            cell_transition |= (1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction)))
+        else:
+            cell_transition &= ~(1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction)))
+
+        return cell_transition
+
+    def rotate_transition(self, cell_transition, rotation=0):
+        """
+        Clockwise-rotate a 64-bit transition bitmap by
+        rotation={0, 45, 90, 135, 180, 225, 270, 315} degrees.
+
+        Parameters
+        ----------
+        cell_transition : int
+            64 bits used to encode the valid transitions for a cell.
+        rotation : int
+            Angle by which to clock-wise rotate the transition bits in
+            `cell_transition' by. I.e., rotation={0, 45, 90, 135, 180,
+            225, 270, 315} degrees.
+
+        Returns
+        -------
+        int
+            An updated bitmap that replaces the original transitions bits
+            with the equivalent bitmap after rotation.
+
+        """
+        # TODO: WARNING: this part of the function has never been tested!
+
+        # Rotate the individual bits in each block
+        value = cell_transition
+        rotation = rotation // 45
+        for i in range(8):
+            block_tuple = self.get_transitions(value, i)
+            block_tuple = block_tuple[rotation:] + block_tuple[:rotation]
+            value = self.set_transitions(value, i, block_tuple)
+
+        # Rotate the 8bits blocks
+        value = ((value & (2 ** (rotation * 8) - 1)) << ((8 - rotation) * 8)) | (value >> (rotation * 8))
+
+        cell_transition = value
+
+        return cell_transition
+
+    def get_direction_enum(self) -> IntEnum:
+        return Grid8TransitionsEnum
diff --git a/flatland/core/grid/rail_env_grid.py b/flatland/core/grid/rail_env_grid.py
new file mode 100644
index 00000000..efb5ea15
--- /dev/null
+++ b/flatland/core/grid/rail_env_grid.py
@@ -0,0 +1,124 @@
+from flatland.core.grid.grid4 import Grid4Transitions
+
+
+class RailEnvTransitions(Grid4Transitions):
+    """
+    Special case of `GridTransitions' over a 2D-grid, with a pre-defined set
+    of transitions mimicking the types of real Swiss rail connections.
+
+    --------------------------------------------------------------------------
+
+    As no diagonal transitions are allowed in the RailEnv environment, the
+    possible transitions for RailEnv from a cell to its neighboring ones
+    are represented over 16 bits.
+
+    The 16 bits are organized in 4 blocks of 4 bits each, the direction that
+    the agent is facing.
+    E.g., the most-significant 4-bits represent the possible movements (NESW)
+    if the agent is facing North, etc...
+
+    agent's direction:          North    East   South   West
+    agent's allowed movements:  [nesw]   [nesw] [nesw]  [nesw]
+    example:                     1000     0000   0010    0000
+
+    In the example, the agent can move from North to South and viceversa.
+    """
+
+    # Contains the basic transitions;
+    # the set of all valid transitions is obtained by successive 90-degree rotation of one of these basic transitions.
+    transition_list = [int('0000000000000000', 2),  # empty cell - Case 0
+                       int('1000000000100000', 2),  # Case 1 - straight
+                       int('1001001000100000', 2),  # Case 2 - simple switch
+                       int('1000010000100001', 2),  # Case 3 - diamond drossing
+                       int('1001011000100001', 2),  # Case 4 - single slip
+                       int('1100110000110011', 2),  # Case 5 - double slip
+                       int('0101001000000010', 2),  # Case 6 - symmetrical
+                       int('0010000000000000', 2),  # Case 7 - dead end
+                       int('0100000000000010', 2),  # Case 1b (8)  - simple turn right
+                       int('0001001000000000', 2),  # Case 1c (9)  - simple turn left
+                       int('1100000000100010', 2)]  # Case 2b (10) - simple switch mirrored
+
+    def __init__(self):
+        super(RailEnvTransitions, self).__init__(
+            transitions=self.transition_list
+        )
+
+        # These bits represent all the possible dead ends
+        self.maskDeadEnds = 0b0010000110000100
+
+        # create this to make validation faster
+        self.transitions_all = set()
+        for index, trans in enumerate(self.transitions):
+            self.transitions_all.add(trans)
+            if index in (2, 4, 6, 7, 8, 9, 10):
+                for _ in range(3):
+                    trans = self.rotate_transition(trans, rotation=90)
+                    self.transitions_all.add(trans)
+            elif index in (1, 5):
+                trans = self.rotate_transition(trans, rotation=90)
+                self.transitions_all.add(trans)
+
+    def print(self, cell_transition):
+        print("  NESW")
+        print("N", format(cell_transition >> (3 * 4) & 0xF, '04b'))
+        print("E", format(cell_transition >> (2 * 4) & 0xF, '04b'))
+        print("S", format(cell_transition >> (1 * 4) & 0xF, '04b'))
+        print("W", format(cell_transition >> (0 * 4) & 0xF, '04b'))
+
+    def repr(self, cell_transition, version=0):
+        """
+        Provide a string representation of the cell transitions.
+        This class doesn't represent an individual cell,
+        but a way of interpreting the contents of a cell.
+        So using the ad hoc name repr rather than __repr__.
+        """
+        # binary format string without leading 0b
+        sbinTrans = format(cell_transition, "#018b")[2:]
+        if version == 0:
+            sRepr = " ".join([
+                "{}:{}".format(sDir, sbinTrans[i:(i + 4)])
+                for i, sDir in
+                zip(
+                    range(0, len(sbinTrans), 4),
+                    self.lsDirs)])  # NESW
+            return sRepr
+
+        if version == 1:
+            lsRepr = []
+            for iDirIn in range(0, 4):
+                sDirTrans = sbinTrans[(iDirIn * 4):(iDirIn * 4 + 4)]
+                if sDirTrans == "0000":
+                    continue
+                sDirsOut = [
+                    self.lsDirs[iDirOut]
+                    for iDirOut in range(0, 4)
+                    if sDirTrans[iDirOut] == "1"]
+                lsRepr.append(self.lsDirs[iDirIn] + ":" + "".join(sDirsOut))
+
+            return ", ".join(lsRepr)
+
+    def is_valid(self, cell_transition):
+        """
+        Checks if a cell transition is a valid cell setup.
+
+        Parameters
+        ----------
+        cell_transition : int
+            64 bits used to encode the valid transitions for a cell.
+
+        Returns
+        -------
+        Boolean
+            True or False
+        """
+        return cell_transition in self.transitions_all
+
+    def has_deadend(self, cell_transition):
+        if cell_transition & self.maskDeadEnds > 0:
+            return True
+        else:
+            return False
+
+    def remove_deadends(self, cell_transition):
+        cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff
+        return cell_transition
diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index 6c9bde42..6c0b92a7 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -6,7 +6,7 @@ import numpy as np
 from importlib_resources import path
 from numpy import array
 
-from .transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions
+from flatland.core.grid.grid4 import Grid4Transitions
 
 
 class TransitionMap:
@@ -73,7 +73,7 @@ class TransitionMap:
 
         Returns
         -------
-        int or float (depending on derived class)
+        int or float (depending on Transitions used)
             Validity of the requested transition (e.g.,
             0/1 allowed/not allowed, a probability in [0,1], etc...)
 
@@ -95,7 +95,7 @@ class TransitionMap:
             Index of the transition to probe, as index in the tuple returned by
             get_transitions(). e.g., the NESW direction of movement, for agents
             on a grid.
-        new_transition : int or float (depending on derived class)
+        new_transition : int or float (depending on Transitions used)
             Validity of the requested transition (e.g.,
             0/1 allowed/not allowed, a probability in [0,1], etc...)
 
@@ -130,10 +130,7 @@ class GridTransitionMap(TransitionMap):
         self.height = height
         self.transitions = transitions
 
-        if isinstance(self.transitions, Grid4Transitions) or isinstance(self.transitions, RailEnvTransitions):
-            self.grid = np.ndarray((height, width), dtype=np.uint16)
-        elif isinstance(self.transitions, Grid8Transitions):
-            self.grid = np.ndarray((height, width), dtype=np.uint64)
+        self.grid = np.zeros((height, width), dtype=self.transitions.get_type())
 
     def get_transitions(self, cell_id):
         """
@@ -156,14 +153,12 @@ class GridTransitionMap(TransitionMap):
             List of the validity of transitions in the cell.
 
         """
+        assert len(cell_id) in (2, 3), \
+            'GridTransitionMap.get_transitions() ERROR: cell_id tuple must have length 2 or 3.'
         if len(cell_id) == 3:
             return self.transitions.get_transitions(self.grid[cell_id[0]][cell_id[1]], cell_id[2])
         elif len(cell_id) == 2:
             return self.grid[cell_id[0]][cell_id[1]]
-        else:
-            print('GridTransitionMap.get_transitions() ERROR: \
-                   wrong cell_id tuple.')
-            return ()
 
     def set_transitions(self, cell_id, new_transitions):
         """
@@ -182,15 +177,14 @@ class GridTransitionMap(TransitionMap):
             Tuple of new transitions validitiy for the cell.
 
         """
+        assert len(cell_id) in (2, 3), \
+            'GridTransitionMap.set_transitions() ERROR: cell_id tuple must have length 2 or 3.'
         if len(cell_id) == 3:
             self.grid[cell_id[0]][cell_id[1]] = self.transitions.set_transitions(self.grid[cell_id[0]][cell_id[1]],
                                                                                  cell_id[2],
                                                                                  new_transitions)
         elif len(cell_id) == 2:
             self.grid[cell_id[0]][cell_id[1]] = new_transitions
-        else:
-            print('GridTransitionMap.get_transitions() ERROR: \
-                   wrong cell_id tuple.')
 
     def get_transition(self, cell_id, transition_index):
         """
@@ -210,15 +204,14 @@ class GridTransitionMap(TransitionMap):
 
         Returns
         -------
-        int or float (depending on derived class)
+        int or float (depending on Transitions used in the )
             Validity of the requested transition (e.g.,
             0/1 allowed/not allowed, a probability in [0,1], etc...)
 
         """
-        if len(cell_id) != 3:
-            print('GridTransitionMap.get_transition() ERROR: \
-                   wrong cell_id tuple.')
-            return ()
+
+        assert len(cell_id) == 3, \
+            'GridTransitionMap.get_transition() ERROR: cell_id tuple must have length 2 or 3.'
         return self.transitions.get_transition(self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index)
 
     def set_transition(self, cell_id, transition_index, new_transition, remove_deadends=False):
@@ -236,15 +229,13 @@ class GridTransitionMap(TransitionMap):
             Index of the transition to probe, as index in the tuple returned by
             get_transitions(). e.g., the NESW direction of movement, for agents
             on a grid.
-        new_transition : int or float (depending on derived class)
+        new_transition : int or float (depending on Transitions used in the map.)
             Validity of the requested transition (e.g.,
             0/1 allowed/not allowed, a probability in [0,1], etc...)
 
         """
-        if len(cell_id) != 3:
-            print('GridTransitionMap.set_transition() ERROR: \
-                   wrong cell_id tuple.')
-            return
+        assert len(cell_id) == 3, \
+            'GridTransitionMap.set_transition() ERROR: cell_id tuple must have length 3.'
         self.grid[cell_id[0]][cell_id[1]] = self.transitions.set_transition(
             self.grid[cell_id[0]][cell_id[1]],
             cell_id[2],
diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py
index 1c3c924a..29b57c40 100644
--- a/flatland/core/transitions.py
+++ b/flatland/core/transitions.py
@@ -5,8 +5,6 @@ possible transitions over a 2D grid.
 """
 from enum import IntEnum
 
-import numpy as np
-
 
 class Transitions:
     """
@@ -17,6 +15,9 @@ class Transitions:
     `orientation' and moving into direction `direction')
     """
 
+    def get_type(self):
+        raise NotImplementedError()
+
     def get_transitions(self, cell_transition, orientation):
         """
         Return a tuple of transitions available in a cell specified by
@@ -132,525 +133,3 @@ class Transitions:
 
     def get_direction_enum(self) -> IntEnum:
         raise NotImplementedError()
-
-
-class Grid4TransitionsEnum(IntEnum):
-    NORTH = 0
-    EAST = 1
-    SOUTH = 2
-    WEST = 3
-
-
-class Grid4Transitions(Transitions):
-    """
-    Grid4Transitions class derived from Transitions.
-
-    Special case of `Transitions' over a 2D-grid (FlatLand).
-    Transitions are possible to neighboring cells on the grid if allowed.
-    GridTransitions keeps track of valid transitions supplied as `transitions'
-    list, each represented as a bitmap of 16 bits.
-
-    Whether a transition is allowed or not depends on which direction an agent
-    inside the cell is facing (0=North, 1=East, 2=South, 3=West) and which
-    direction the agent wants to move to
-    (North, East, South, West, relative to the cell).
-    Each transition (orientation, direction)
-    can be allowed (1) or forbidden (0).
-
-    For example, in case of no diagonal transitions on the grid, the 16 bits
-    of the transition bitmaps are organized in 4 blocks of 4 bits each, the
-    direction that the agent is facing.
-    E.g., the most-significant 4-bits represent the possible movements (NESW)
-    if the agent is facing North, etc...
-
-    agent's direction:          North    East   South   West
-    agent's allowed movements:  [nesw]   [nesw] [nesw]  [nesw]
-    example:                     1000     0000   0010    0000
-
-    In the example, the agent can move from North to South and viceversa.
-    """
-
-    def __init__(self, transitions):
-        self.transitions = transitions
-        self.sDirs = "NESW"
-        self.lsDirs = list(self.sDirs)
-
-        # row,col delta for each direction
-        self.gDir2dRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
-
-    def get_transitions(self, cell_transition, orientation):
-        """
-        Get the 4 possible transitions ((N,E,S,W), 4 elements tuple
-        if no diagonal transitions allowed) available for an agent oriented
-        in direction `orientation' and inside a cell with
-        transitions `cell_transition'.
-
-        Parameters
-        ----------
-        cell_transition : int
-            16 bits used to encode the valid transitions for a cell.
-        orientation : int
-            Orientation of the agent inside the cell.
-
-        Returns
-        -------
-        tuple
-            List of the validity of transitions in the cell.
-
-        """
-        bits = (cell_transition >> ((3 - orientation) * 4))
-        return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1)
-
-    def set_transitions(self, cell_transition, orientation, new_transitions):
-        """
-        Set the possible transitions (e.g., (N,E,S,W), 4 elements tuple
-        if no diagonal transitions allowed) available for an agent
-        oriented in direction `orientation' and inside a cell with transitions
-        `cell_transition'. A new `cell_transition' is returned with
-        the specified bits replaced by `new_transitions'.
-
-        Parameters
-        ----------
-        cell_transition : int
-            16 bits used to encode the valid transitions for a cell.
-        orientation : int
-            Orientation of the agent inside the cell.
-        new_transitions : tuple
-            Tuple of new transitions validitiy for the cell.
-
-        Returns
-        -------
-        int
-            An updated bitmap that replaces the original transitions validity
-            of `cell_transition' with `new_transitions', for the appropriate
-            `orientation'.
-
-        """
-        mask = (1 << ((4 - orientation) * 4)) - (1 << ((3 - orientation) * 4))
-        negmask = ~mask
-
-        new_transitions = \
-            (new_transitions[0] & 1) << 3 | \
-            (new_transitions[1] & 1) << 2 | \
-            (new_transitions[2] & 1) << 1 | \
-            (new_transitions[3] & 1)
-
-        cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4))
-
-        return cell_transition
-
-    def get_transition(self, cell_transition, orientation, direction):
-        """
-        Get the transition bit (1 value) that determines whether an agent
-        oriented in direction `orientation' and inside a cell with transitions
-        `cell_transition' can move to the cell in direction `direction'
-        relative to the current cell.
-
-        Parameters
-        ----------
-        cell_transition : int
-            16 bits used to encode the valid transitions for a cell.
-        orientation : int
-            Orientation of the agent inside the cell.
-        direction : int
-            Direction of movement whose validity is to be tested.
-
-        Returns
-        -------
-        int
-            Validity of the requested transition: 0/1 allowed/not allowed.
-
-        """
-        return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1
-
-    def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False):
-        """
-        Set the transition bit (1 value) that determines whether an agent
-        oriented in direction `orientation' and inside a cell with transitions
-        `cell_transition' can move to the cell in direction `direction'
-        relative to the current cell.
-
-        Parameters
-        ----------
-        cell_transition : int
-            16 bits used to encode the valid transitions for a cell.
-        orientation : int
-            Orientation of the agent inside the cell.
-        direction : int
-            Direction of movement whose validity is to be tested.
-        new_transition : int
-            Validity of the requested transition: 0/1 allowed/not allowed.
-        remove_deadends -- boolean, default False
-            remove all deadend transitions.
-        Returns
-        -------
-        int
-            An updated bitmap that replaces the original transitions validity
-            of `cell_transition' with `new_transitions', for the appropriate
-            `orientation'.
-
-        """
-        if new_transition:
-            cell_transition |= (1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction)))
-        else:
-            cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction)))
-
-        if remove_deadends:
-            cell_transition = self.remove_deadends(cell_transition)
-
-        return cell_transition
-
-    def rotate_transition(self, cell_transition, rotation=0):
-        """
-        Clockwise-rotate a 16-bit transition bitmap by
-        rotation={0, 90, 180, 270} degrees.
-
-        Parameters
-        ----------
-        cell_transition : int
-            16 bits used to encode the valid transitions for a cell.
-        rotation : int
-            Angle by which to clock-wise rotate the transition bits in
-            `cell_transition' by. I.e., rotation={0, 90, 180, 270} degrees.
-
-        Returns
-        -------
-        int
-            An updated bitmap that replaces the original transitions bits
-            with the equivalent bitmap after rotation.
-
-        """
-        # Rotate the individual bits in each block
-        value = cell_transition
-        rotation = rotation // 90
-        for i in range(4):
-            block_tuple = self.get_transitions(value, i)
-            block_tuple = block_tuple[(4 - rotation):] + block_tuple[:(4 - rotation)]
-            value = self.set_transitions(value, i, block_tuple)
-
-        # Rotate the 4-bits blocks
-        value = ((value & (2 ** (rotation * 4) - 1)) << ((4 - rotation) * 4)) | (value >> (rotation * 4))
-
-        cell_transition = value
-        return cell_transition
-
-    def get_direction_enum(self) -> IntEnum:
-        return Grid4TransitionsEnum
-
-
-class Grid8TransitionsEnum(IntEnum):
-    NORTH = 0
-    NORTH_EAST = 1
-    EAST = 2
-    SOUTH_EAST = 3
-    SOUTH = 4
-    SOUTH_WEST = 5
-    WEST = 6
-    NORTH_WEST = 7
-
-
-class Grid8Transitions(Transitions):
-    """
-    Grid8Transitions class derived from Transitions.
-
-    Special case of `Transitions' over a 2D-grid (FlatLand).
-    Transitions are possible to neighboring cells on the grid if allowed.
-    GridTransitions keeps track of valid transitions supplied as `transitions'
-    list, each represented as a bitmap of 64 bits.
-
-    0=North, 1=North-East, etc.
-
-    """
-
-    def __init__(self, transitions):
-        self.transitions = transitions
-
-    def get_transitions(self, cell_transition, orientation):
-        """
-        Get the 8 possible transitions.
-
-        Parameters
-        ----------
-        cell_transition : int
-            64 bits used to encode the valid transitions for a cell.
-        orientation : int
-            Orientation of the agent inside the cell.
-
-        Returns
-        -------
-        tuple
-            List of the validity of transitions in the cell.
-
-        """
-        bits = (cell_transition >> ((7 - orientation) * 8))
-        cell_transition = (
-            (bits >> 7) & 1,
-            (bits >> 6) & 1,
-            (bits >> 5) & 1,
-            (bits >> 4) & 1,
-            (bits >> 3) & 1,
-            (bits >> 2) & 1,
-            (bits >> 1) & 1,
-            (bits) & 1)
-
-        return cell_transition
-
-    def set_transitions(self, cell_transition, orientation, new_transitions):
-        """
-        Set the possible transitions.
-
-        Parameters
-        ----------
-        cell_transition : int
-            64 bits used to encode the valid transitions for a cell.
-        orientation : int
-            Orientation of the agent inside the cell.
-        new_transitions : tuple
-            Tuple of new transitions validitiy for the cell.
-
-        Returns
-        -------
-        int
-            An updated bitmap that replaces the original transitions validity
-            of `cell_transition' with `new_transitions', for the appropriate
-            `orientation'.
-
-        """
-        mask = (1 << ((8 - orientation) * 8)) - (1 << ((7 - orientation) * 8))
-        negmask = ~mask
-
-        new_transitions = \
-            (new_transitions[0] & 1) << 7 | \
-            (new_transitions[1] & 1) << 6 | \
-            (new_transitions[2] & 1) << 5 | \
-            (new_transitions[3] & 1) << 4 | \
-            (new_transitions[4] & 1) << 3 | \
-            (new_transitions[5] & 1) << 2 | \
-            (new_transitions[6] & 1) << 1 | \
-            (new_transitions[7] & 1)
-
-        cell_transition = (cell_transition & negmask) | (new_transitions << ((7 - orientation) * 8))
-
-        return cell_transition
-
-    def get_transition(self, cell_transition, orientation, direction):
-        """
-        Get the transition bit (1 value) that determines whether an agent
-        oriented in direction `orientation' and inside a cell with transitions
-        `cell_transition' can move to the cell in direction `direction'
-        relative to the current cell.
-
-        Parameters
-        ----------
-        cell_transition : int
-            64 bits used to encode the valid transitions for a cell.
-        orientation : int
-            Orientation of the agent inside the cell.
-        direction : int
-            Direction of movement whose validity is to be tested.
-
-        Returns
-        -------
-        int
-            Validity of the requested transition: 0/1 allowed/not allowed.
-
-        """
-        return ((cell_transition >> ((8 - 1 - orientation) * 8)) >> (8 - 1 - direction)) & 1
-
-    def set_transition(self, cell_transition, orientation, direction,
-                       new_transition):
-        """
-        Set the transition bit (1 value) that determines whether an agent
-        oriented in direction `orientation' and inside a cell with transitions
-        `cell_transition' can move to the cell in direction `direction'
-        relative to the current cell.
-
-        Parameters
-        ----------
-        cell_transition : int
-            64 bits used to encode the valid transitions for a cell.
-        orientation : int
-            Orientation of the agent inside the cell.
-        direction : int
-            Direction of movement whose validity is to be tested.
-        new_transition : int
-            Validity of the requested transition: 0/1 allowed/not allowed.
-
-        Returns
-        -------
-        int
-            An updated bitmap that replaces the original transitions validity
-            of `cell_transition' with `new_transitions', for the appropriate
-            `orientation'.
-
-        """
-        if new_transition:
-            cell_transition |= (1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction)))
-        else:
-            cell_transition &= ~(1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction)))
-
-        return cell_transition
-
-    def rotate_transition(self, cell_transition, rotation=0):
-        """
-        Clockwise-rotate a 64-bit transition bitmap by
-        rotation={0, 45, 90, 135, 180, 225, 270, 315} degrees.
-
-        Parameters
-        ----------
-        cell_transition : int
-            64 bits used to encode the valid transitions for a cell.
-        rotation : int
-            Angle by which to clock-wise rotate the transition bits in
-            `cell_transition' by. I.e., rotation={0, 45, 90, 135, 180,
-            225, 270, 315} degrees.
-
-        Returns
-        -------
-        int
-            An updated bitmap that replaces the original transitions bits
-            with the equivalent bitmap after rotation.
-
-        """
-        # TODO: WARNING: this part of the function has never been tested!
-
-        # Rotate the individual bits in each block
-        value = cell_transition
-        rotation = rotation // 45
-        for i in range(8):
-            block_tuple = self.get_transitions(value, i)
-            block_tuple = block_tuple[rotation:] + block_tuple[:rotation]
-            value = self.set_transitions(value, i, block_tuple)
-
-        # Rotate the 8bits blocks
-        value = ((value & (2 ** (rotation * 8) - 1)) << ((8 - rotation) * 8)) | (value >> (rotation * 8))
-
-        cell_transition = value
-
-        return cell_transition
-
-    def get_direction_enum(self) -> IntEnum:
-        return Grid8TransitionsEnum
-
-
-class RailEnvTransitions(Grid4Transitions):
-    """
-    Special case of `GridTransitions' over a 2D-grid, with a pre-defined set
-    of transitions mimicking the types of real Swiss rail connections.
-
-    --------------------------------------------------------------------------
-
-    As no diagonal transitions are allowed in the RailEnv environment, the
-    possible transitions for RailEnv from a cell to its neighboring ones
-    are represented over 16 bits.
-
-    The 16 bits are organized in 4 blocks of 4 bits each, the direction that
-    the agent is facing.
-    E.g., the most-significant 4-bits represent the possible movements (NESW)
-    if the agent is facing North, etc...
-
-    agent's direction:          North    East   South   West
-    agent's allowed movements:  [nesw]   [nesw] [nesw]  [nesw]
-    example:                     1000     0000   0010    0000
-
-    In the example, the agent can move from North to South and viceversa.
-    """
-
-    # Contains the basic transitions;
-    # the set of all valid transitions is obtained by successive 90-degree rotation of one of these basic transitions.
-    transition_list = [int('0000000000000000', 2),  # empty cell - Case 0
-                       int('1000000000100000', 2),  # Case 1 - straight
-                       int('1001001000100000', 2),  # Case 2 - simple switch
-                       int('1000010000100001', 2),  # Case 3 - diamond drossing
-                       int('1001011000100001', 2),  # Case 4 - single slip
-                       int('1100110000110011', 2),  # Case 5 - double slip
-                       int('0101001000000010', 2),  # Case 6 - symmetrical
-                       int('0010000000000000', 2),  # Case 7 - dead end
-                       int('0100000000000010', 2),  # Case 1b (8)  - simple turn right
-                       int('0001001000000000', 2),  # Case 1c (9)  - simple turn left
-                       int('1100000000100010', 2)]  # Case 2b (10) - simple switch mirrored
-
-    def __init__(self):
-        super(RailEnvTransitions, self).__init__(
-            transitions=self.transition_list
-        )
-
-        # These bits represent all the possible dead ends
-        self.maskDeadEnds = 0b0010000110000100
-
-        # create this to make validation faster
-        self.transitions_all = set()
-        for index, trans in enumerate(self.transitions):
-            self.transitions_all.add(trans)
-            if index in (2, 4, 6, 7, 8, 9, 10):
-                for _ in range(3):
-                    trans = self.rotate_transition(trans, rotation=90)
-                    self.transitions_all.add(trans)
-            elif index in (1, 5):
-                trans = self.rotate_transition(trans, rotation=90)
-                self.transitions_all.add(trans)
-
-    def print(self, cell_transition):
-        print("  NESW")
-        print("N", format(cell_transition >> (3 * 4) & 0xF, '04b'))
-        print("E", format(cell_transition >> (2 * 4) & 0xF, '04b'))
-        print("S", format(cell_transition >> (1 * 4) & 0xF, '04b'))
-        print("W", format(cell_transition >> (0 * 4) & 0xF, '04b'))
-
-    def repr(self, cell_transition, version=0):
-        """
-        Provide a string representation of the cell transitions.
-        This class doesn't represent an individual cell,
-        but a way of interpreting the contents of a cell.
-        So using the ad hoc name repr rather than __repr__.
-        """
-        # binary format string without leading 0b
-        sbinTrans = format(cell_transition, "#018b")[2:]
-        if version == 0:
-            sRepr = " ".join([
-                "{}:{}".format(sDir, sbinTrans[i:(i + 4)])
-                for i, sDir in
-                zip(
-                    range(0, len(sbinTrans), 4),
-                    self.lsDirs)])  # NESW
-            return sRepr
-
-        if version == 1:
-            lsRepr = []
-            for iDirIn in range(0, 4):
-                sDirTrans = sbinTrans[(iDirIn * 4):(iDirIn * 4 + 4)]
-                if sDirTrans == "0000":
-                    continue
-                sDirsOut = [
-                    self.lsDirs[iDirOut]
-                    for iDirOut in range(0, 4)
-                    if sDirTrans[iDirOut] == "1"]
-                lsRepr.append(self.lsDirs[iDirIn] + ":" + "".join(sDirsOut))
-
-            return ", ".join(lsRepr)
-
-    def is_valid(self, cell_transition):
-        """
-        Checks if a cell transition is a valid cell setup.
-
-        Parameters
-        ----------
-        cell_transition : int
-            64 bits used to encode the valid transitions for a cell.
-
-        Returns
-        -------
-        Boolean
-            True or False
-        """
-        return cell_transition in self.transitions_all
-
-    def has_deadend(self, cell_transition):
-        if cell_transition & self.maskDeadEnds > 0:
-            return True
-        else:
-            return False
-
-    def remove_deadends(self, cell_transition):
-        cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff
-        return cell_transition
diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py
index cc4a0015..19da8946 100644
--- a/flatland/envs/env_utils.py
+++ b/flatland/envs/env_utils.py
@@ -7,7 +7,7 @@ a GridTransitionMap object.
 
 import numpy as np
 
-from flatland.core.transitions import Grid4TransitionsEnum
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
 
 
 def get_direction(pos1, pos2) -> Grid4TransitionsEnum:
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 8b02d445..fa5cdccc 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -1,7 +1,7 @@
 import numpy as np
 
 from flatland.core.transition_map import GridTransitionMap
-from flatland.core.transitions import RailEnvTransitions
+from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.envs.env_utils import distance_on_rail, connect_rail, get_direction, mirror
 from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail
 
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index a7f91f14..ecb06978 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -6,7 +6,7 @@ from collections import deque
 import numpy as np
 
 from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.transitions import Grid4TransitionsEnum
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.envs.env_utils import coordinate_to_position
 
 
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index c7cb8fa1..6098f434 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -28,7 +28,7 @@ enable_windows_cairo_support()
 from cairosvg import svg2png  # noqa: E402
 from screeninfo import get_monitors  # noqa: E402
 
-from flatland.core.transitions import RailEnvTransitions  # noqa: E402
+from flatland.core.grid.rail_env_grid import RailEnvTransitions  # noqa: E402
 
 
 class PILGL(GraphicsLayer):
diff --git a/flatland/utils/svg.py b/flatland/utils/svg.py
index 249b4cb5..b2e02844 100644
--- a/flatland/utils/svg.py
+++ b/flatland/utils/svg.py
@@ -3,7 +3,7 @@ import re
 
 import svgutils
 
-from flatland.core.transitions import RailEnvTransitions
+from flatland.core.grid.rail_env_grid import RailEnvTransitions
 
 
 class SVG(object):
diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py
index cf1a9206..5117b12a 100644
--- a/tests/test_flatland_core_transition_map.py
+++ b/tests/test_flatland_core_transition_map.py
@@ -1,13 +1,21 @@
+from flatland.core.grid.grid4 import Grid4Transitions, Grid4TransitionsEnum
+from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum
 from flatland.core.transition_map import GridTransitionMap
-from flatland.core.transitions import Grid4Transitions, Grid8Transitions, Grid4TransitionsEnum
 
 
 def test_grid4_set_transitions():
     grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))
-    grid4_map.set_transition((0, 0), Grid4TransitionsEnum.EAST, 1)
-    actual_transitions  = grid4_map.get_transitions((0,0))
-    assert False
+    assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0)
+    grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 1)
+    assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (1, 0, 0, 0)
+    grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 0)
+    assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0)
 
 
 def test_grid8_set_transitions():
     grid8_map = GridTransitionMap(2, 2, Grid8Transitions([]))
+    assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0)
+    grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 1)
+    assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (1, 0, 0, 0, 0, 0, 0, 0)
+    grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0)
+    assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0)
diff --git a/tests/test_flatland_core_transitions.py b/tests/test_flatland_core_transitions.py
index 47def83c..14cf3073 100644
--- a/tests/test_flatland_core_transitions.py
+++ b/tests/test_flatland_core_transitions.py
@@ -4,7 +4,8 @@
 """Tests for `flatland` package."""
 import numpy as np
 
-from flatland.core.transitions import RailEnvTransitions, Grid8Transitions
+from flatland.core.grid.grid8 import Grid8Transitions
+from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.envs.env_utils import validate_new_transition
 
 
@@ -194,7 +195,7 @@ def test_diagonal_transitions():
 
     # Allowing transition from north to southwest: Facing south, going SW
     north_southwest_transition = \
-        diagonal_trans_env.set_transitions(int('0' * 64, 2), 4, (0, 0, 0, 0, 0, 1, 0, 0))
+        diagonal_trans_env.set_transitions(0, 4, (0, 0, 0, 0, 0, 1, 0, 0))
 
     assert (diagonal_trans_env.rotate_transition(
         south_northeast_transition, 180) == north_southwest_transition)
diff --git a/tests/test_flatland_envs_env_utils.py b/tests/test_flatland_envs_env_utils.py
index 25952031..49b619a1 100644
--- a/tests/test_flatland_envs_env_utils.py
+++ b/tests/test_flatland_envs_env_utils.py
@@ -1,7 +1,7 @@
 import numpy as np
 import pytest
 
-from flatland.core.transitions import Grid4TransitionsEnum
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.envs.env_utils import position_to_coordinate, coordinate_to_position, get_direction
 
 depth_to_test = 5
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index f34829d7..16850672 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -3,8 +3,8 @@
 
 import numpy as np
 
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
-from flatland.core.transitions import Grid4TransitionsEnum
 from flatland.envs.generators import rail_from_GridTransitionMap_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index 0279ec7b..3a50c482 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -2,8 +2,9 @@
 # -*- coding: utf-8 -*-
 import numpy as np
 
+from flatland.core.grid.grid4 import Grid4Transitions
+from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
-from flatland.core.transitions import Grid4Transitions, RailEnvTransitions
 from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.agent_utils import EnvAgentStatic
 from flatland.envs.generators import complex_rail_generator
-- 
GitLab