From ef60e4ce5df50ea490bdf9ae0363d960e9cc8fb5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mattias=20Ljungstr=C3=B6m?= <ml@mljx.io>
Date: Tue, 9 Apr 2019 19:55:28 +0200
Subject: [PATCH] refactor GridTransition into Grid4, Grid8Transition.

---
 flatland/core/env.py                 |  12 +-
 flatland/core/transitions.py         | 386 +++++++++++++++++----------
 flatland/utils/rail_env_generator.py |   8 +-
 flatland/utils/rendertools.py        |   7 +-
 tests/test_environments.py           |  22 +-
 tests/test_transitions.py            |  32 +--
 6 files changed, 278 insertions(+), 189 deletions(-)

diff --git a/flatland/core/env.py b/flatland/core/env.py
index 2ecee63..bf6df54 100644
--- a/flatland/core/env.py
+++ b/flatland/core/env.py
@@ -144,9 +144,7 @@ class RailEnv:
 
         self.agents_handles = list(range(self.number_of_agents))
 
-        self.t_utils = RailEnvTransitions()
-        # TODO : bad hack for pylint 80 characters per line; shortened function
-        self.gtfotd = self.t_utils.get_transition_from_orientation_to_direction
+        self.trans = RailEnvTransitions()
 
     def get_agent_handles(self):
         return self.agents_handles
@@ -177,7 +175,7 @@ class RailEnv:
                 valid_movements = []
                 for direction in range(4):
                     position = self.agents_position[i]
-                    moves = self.t_utils.get_transitions_from_orientation(
+                    moves = self.trans.get_transitions(
                              self.rail[position[0]][position[1]], direction)
                     for move_index in range(4):
                         if moves[move_index]:
@@ -272,7 +270,7 @@ class RailEnv:
                         elif direction == 3:
                             reverse_direction = 1
 
-                        valid_transition = self.gtfotd(
+                        valid_transition = self.trans.get_transition(
                                             self.rail[pos[0]][pos[1]],
                                             reverse_direction,
                                             reverse_direction)
@@ -295,7 +293,7 @@ class RailEnv:
                 else:
                     new_cell_isValid = False
 
-                transition_isValid = self.gtfotd(
+                transition_isValid = self.trans.get_transition(
                      self.rail[pos[0]][pos[1]],
                      direction,
                      movement)
@@ -364,7 +362,7 @@ class RailEnv:
                 return 1
             if node not in visited:
                 visited.add(node)
-                moves = self.t_utils.get_transitions_from_orientation(
+                moves = self.trans.get_transitions(
                          self.rail[node[0][0]][node[0][1]], node[1])
                 for move_index in range(4):
                     if moves[move_index]:
diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py
index 9c4f05e..78e76f5 100644
--- a/flatland/core/transitions.py
+++ b/flatland/core/transitions.py
@@ -14,7 +14,7 @@ class Transitions:
     `orientation' and moving into direction `direction')
     """
 
-    def get_transitions_from_orientation(self, cell_transition, orientation):
+    def get_transitions(self, cell_transition, orientation):
         """
         Return a tuple of transitions available in a cell specified by
         `cell_transition' for an agent facing direction `orientation'
@@ -39,8 +39,7 @@ class Transitions:
         """
         raise NotImplementedError()
 
-    def set_transitions_from_orientation(self, cell_transition, orientation,
-                                         new_transitions):
+    def set_transitions(self, cell_transition, orientation, new_transitions):
         """
         Return a `cell_transition' specification where the transitions
         available for an agent facing direction `orientation' are replaced
@@ -68,8 +67,7 @@ class Transitions:
         """
         raise NotImplementedError()
 
-    def get_transition_from_orientation_to_direction(self, cell_transition,
-                                                     orientation, direction):
+    def get_transition(self, cell_transition, orientation, direction):
         """
         Return the status of whether an agent oriented in directions
         `orientation' and inside a cell with transitions `cell_transition'
@@ -96,11 +94,8 @@ class Transitions:
         """
         raise NotImplementedError()
 
-    def set_transition_from_orientation_to_direction(self,
-                                                     cell_transition,
-                                                     orientation,
-                                                     direction,
-                                                     new_transition):
+    def set_transition(self, cell_transition, orientation, direction,
+                       new_transition):
         """
         Return a `cell_transition' specification where the status of
         whether an agent oriented in direction `orientation' and inside
@@ -133,15 +128,14 @@ class Transitions:
         raise NotImplementedError()
 
 
-class GridTransitions(Transitions):
+class Grid4Transitions(Transitions):
     """
-    GridTransitions class derived from Transitions.
+    Grid4Transitions class derived from Transitions.
 
     Special case of `Transitions' over a 2D-grid (FlatLand).
     Transitions are possible to neighboring cells on the grid if allowed.
     GridTransitions keeps track of valid transitions supplied as `transitions'
-    list, each represented as a bitmap of 16 (allow_diagonal_transitions=False)
-    or 64 bits (allow_diagonal_transitions=True).
+    list, each represented as a bitmap of 16 bits.
 
     Whether a transition is allowed or not depends on which direction an agent
     inside the cell is facing (0=North, 1=East, 2=South, 3=West) and which
@@ -163,19 +157,10 @@ class GridTransitions(Transitions):
     In the example, the agent can move from North to South and viceversa.
     """
 
-    def __init__(self,
-                 transitions,
-                 allow_diagonal_transitions=False
-                 ):
-
-        if allow_diagonal_transitions:
-            self.number_of_cell_neighbors = 8
-        else:
-            self.number_of_cell_neighbors = 4
-
+    def __init__(self, transitions):
         self.transitions = transitions
 
-    def get_transitions_from_orientation(self, cell_transition, orientation):
+    def get_transitions(self, cell_transition, orientation):
         """
         Get the 4 possible transitions ((N,E,S,W), 4 elements tuple
         if no diagonal transitions allowed) available for an agent oriented
@@ -185,7 +170,7 @@ class GridTransitions(Transitions):
         Parameters
         ----------
         cell_transition : int
-            16 or 64 bits used to encode the valid transitions for a cell.
+            16 bits used to encode the valid transitions for a cell.
         orientation : int
             Orientation of the agent inside the cell.
 
@@ -195,28 +180,10 @@ class GridTransitions(Transitions):
             List of the validity of transitions in the cell.
 
         """
-        if self.number_of_cell_neighbors == 4:
-            bits = (cell_transition >> ((3-orientation)*4))
-            cell_transition = ((bits >> 3) & 1, (bits >> 2) & 1,
-                               (bits >> 1) & 1, (bits) & 1)
-        elif self.number_of_cell_neighbors == 8:
-            bits = (cell_transition >> ((7-orientation)*8))
-            cell_transition = (
-                (bits >> 7) & 1,
-                (bits >> 6) & 1,
-                (bits >> 5) & 1,
-                (bits >> 4) & 1,
-                (bits >> 3) & 1,
-                (bits >> 2) & 1,
-                (bits >> 1) & 1,
-                (bits) & 1)
-        else:
-            raise NotImplementedError()
+        bits = (cell_transition >> ((3-orientation)*4))
+        return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1)
 
-        return cell_transition
-
-    def set_transitions_from_orientation(self, cell_transition, orientation,
-                                         new_transitions):
+    def set_transitions(self, cell_transition, orientation, new_transitions):
         """
         Set the possible transitions (e.g., (N,E,S,W), 4 elements tuple
         if no diagonal transitions allowed) available for an agent
@@ -227,7 +194,7 @@ class GridTransitions(Transitions):
         Parameters
         ----------
         cell_transition : int
-            16 or 64 bits used to encode the valid transitions for a cell.
+            16 bits used to encode the valid transitions for a cell.
         orientation : int
             Orientation of the agent inside the cell.
         new_transitions : tuple
@@ -241,43 +208,22 @@ class GridTransitions(Transitions):
             `orientation'.
 
         """
-        if self.number_of_cell_neighbors == 4:
-            mask = (1 << ((4-orientation)*4)) - (1 << ((3-orientation)*4))
-            negmask = ~mask
-
-            new_transitions = \
-                (new_transitions[0] & 1) << 3 | \
-                (new_transitions[1] & 1) << 2 | \
-                (new_transitions[2] & 1) << 1 | \
-                (new_transitions[3] & 1)
-
-            cell_transition = \
-                (
-                    cell_transition & negmask) | \
-                (new_transitions << ((3-orientation)*4))
-        elif self.number_of_cell_neighbors == 8:
-            mask = (1 << ((8-orientation)*8)) - (1 << ((7-orientation)*8))
-            negmask = ~mask
-
-            new_transitions = \
-                (new_transitions[0] & 1) << 7 | \
-                (new_transitions[1] & 1) << 6 | \
-                (new_transitions[2] & 1) << 5 | \
-                (new_transitions[3] & 1) << 4 | \
-                (new_transitions[4] & 1) << 3 | \
-                (new_transitions[5] & 1) << 2 | \
-                (new_transitions[6] & 1) << 1 | \
-                (new_transitions[7] & 1)
-
-            cell_transition = (cell_transition & negmask) | (
-                new_transitions << ((7-orientation)*8))
-        else:
-            raise NotImplementedError()
+        mask = (1 << ((4-orientation)*4)) - (1 << ((3-orientation)*4))
+        negmask = ~mask
+
+        new_transitions = \
+            (new_transitions[0] & 1) << 3 | \
+            (new_transitions[1] & 1) << 2 | \
+            (new_transitions[2] & 1) << 1 | \
+            (new_transitions[3] & 1)
+
+        cell_transition = \
+            (cell_transition & negmask) | \
+            (new_transitions << ((3-orientation)*4))
 
         return cell_transition
 
-    def get_transition_from_orientation_to_direction(self, cell_transition,
-                                                     orientation, direction):
+    def get_transition(self, cell_transition, orientation, direction):
         """
         Get the transition bit (1 value) that determines whether an agent
         oriented in direction `orientation' and inside a cell with transitions
@@ -287,7 +233,7 @@ class GridTransitions(Transitions):
         Parameters
         ----------
         cell_transition : int
-            16 or 64 bits used to encode the valid transitions for a cell.
+            16 bits used to encode the valid transitions for a cell.
         orientation : int
             Orientation of the agent inside the cell.
         direction : int
@@ -299,14 +245,10 @@ class GridTransitions(Transitions):
             Validity of the requested transition: 0/1 allowed/not allowed.
 
         """
-        return ((cell_transition >>
-                 ((self.number_of_cell_neighbors-1-orientation) *
-                  self.number_of_cell_neighbors)) >>
-                (self.number_of_cell_neighbors-1-direction)) & 1
-
-    def set_transition_from_orientation_to_direction(self, cell_transition,
-                                                     orientation, direction,
-                                                     new_transition):
+        return ((cell_transition >> ((4-1-orientation) * 4)) >> (4-1-direction)) & 1
+
+    def set_transition(self, cell_transition, orientation, direction,
+                       new_transition):
         """
         Set the transition bit (1 value) that determines whether an agent
         oriented in direction `orientation' and inside a cell with transitions
@@ -316,7 +258,7 @@ class GridTransitions(Transitions):
         Parameters
         ----------
         cell_transition : int
-            16 or 64 bits used to encode the valid transitions for a cell.
+            16 bits used to encode the valid transitions for a cell.
         orientation : int
             Orientation of the agent inside the cell.
         direction : int
@@ -333,34 +275,27 @@ class GridTransitions(Transitions):
 
         """
         if new_transition:
-            cell_transition |= \
-                (1 << ((self.number_of_cell_neighbors-1-orientation) *
-                       self.number_of_cell_neighbors +
-                       (self.number_of_cell_neighbors - 1 - direction)))
+            cell_transition |= (1 << ((4-1-orientation) * 4 + 
+                                (4 - 1 - direction)))
         else:
             cell_transition &= \
-                ~(1 << ((self.number_of_cell_neighbors-1-orientation) *
-                        self.number_of_cell_neighbors +
-                        (self.number_of_cell_neighbors - 1 - direction)))
+                ~(1 << ((4-1-orientation) * 4 +
+                        (4 - 1 - direction)))
 
         return cell_transition
 
     def rotate_transition(self, cell_transition, rotation=0):
         """
-        Clockwise-rotate a 16-bit or 64-bit transition bitmap by
-        rotation={0, 90, 180, 270} degrees in diagonal steps are not allowed,
-        or by rotation={0, 45, 90, 135, 180, 225, 270, 315} degrees if \
-        they are.
+        Clockwise-rotate a 16-bit transition bitmap by
+        rotation={0, 90, 180, 270} degrees.
 
         Parameters
         ----------
         cell_transition : int
-            16 or 64 bits used to encode the valid transitions for a cell.
+            16 bits used to encode the valid transitions for a cell.
         rotation : int
             Angle by which to clock-wise rotate the transition bits in
-            `cell_transition' by. I.e., rotation={0, 90, 180, 270} degrees in
-            diagonal steps are not allowed, or by
-            rotation={0, 45, 90, 135, 180, 225, 270, 315} degrees if they are.
+            `cell_transition' by. I.e., rotation={0, 90, 180, 270} degrees.
 
         Returns
         -------
@@ -369,48 +304,210 @@ class GridTransitions(Transitions):
             with the equivalent bitmap after rotation.
 
         """
-        if self.number_of_cell_neighbors == 4:
-            # Rotate the individual bits in each block
-            value = cell_transition
-            rotation = rotation // 90
-            for i in range(4):
-                block_tuple = self.get_transitions_from_orientation(value, i)
-                block_tuple = block_tuple[(
-                    4-rotation):] + block_tuple[:(4-rotation)]
-                value = self.set_transitions_from_orientation(
-                    value, i, block_tuple)
-
-            # Rotate the 4bits blocks
-            value = ((value & (2**(rotation*4)-1)) <<
-                     ((4-rotation)*4)) | (value >> (rotation*4))
-
-            cell_transition = value
-
-        elif self.number_of_cell_neighbors == 8:
-            # TODO: WARNING: this part of the function has never been tested!
-
-            # Rotate the individual bits in each block
-            value = cell_transition
-            rotation = rotation // 45
-            for i in range(8):
-                block_tuple = self.get_transitions_from_orientation(value, i)
-                block_tuple = block_tuple[rotation:] + block_tuple[:rotation]
-                value = self.set_transitions_from_orientation(
-                    value, i, block_tuple)
-
-            # Rotate the 8bits blocks
-            value = ((value & (2**(rotation*8)-1)) <<
-                     ((8-rotation)*8)) | (value >> (rotation*8))
-
-            cell_transition = value
+        # Rotate the individual bits in each block
+        value = cell_transition
+        rotation = rotation // 90
+        for i in range(4):
+            block_tuple = self.get_transitions(value, i)
+            block_tuple = block_tuple[(
+                4-rotation):] + block_tuple[:(4-rotation)]
+            value = self.set_transitions(value, i, block_tuple)
+
+        # Rotate the 4-bits blocks
+        value = ((value & (2**(rotation*4)-1)) <<
+                    ((4-rotation)*4)) | (value >> (rotation*4))
+
+        cell_transition = value
+        return cell_transition
+
+
+class Grid8Transitions(Transitions):
+    """
+    Grid8Transitions class derived from Transitions.
 
+    Special case of `Transitions' over a 2D-grid (FlatLand).
+    Transitions are possible to neighboring cells on the grid if allowed.
+    GridTransitions keeps track of valid transitions supplied as `transitions'
+    list, each represented as a bitmap of 64 bits.
+
+    0=North, 1=North-East, etc.
+
+    """
+
+    def __init__(self, transitions):
+        self.transitions = transitions
+
+    def get_transitions(self, cell_transition, orientation):
+        """
+        Get the 8 possible transitions.
+
+        Parameters
+        ----------
+        cell_transition : int
+            64 bits used to encode the valid transitions for a cell.
+        orientation : int
+            Orientation of the agent inside the cell.
+
+        Returns
+        -------
+        tuple
+            List of the validity of transitions in the cell.
+
+        """
+        bits = (cell_transition >> ((7-orientation)*8))
+        cell_transition = (
+            (bits >> 7) & 1,
+            (bits >> 6) & 1,
+            (bits >> 5) & 1,
+            (bits >> 4) & 1,
+            (bits >> 3) & 1,
+            (bits >> 2) & 1,
+            (bits >> 1) & 1,
+            (bits) & 1)
+
+        return cell_transition
+
+    def set_transitions(self, cell_transition, orientation, new_transitions):
+        """
+        Set the possible transitions.
+
+        Parameters
+        ----------
+        cell_transition : int
+            64 bits used to encode the valid transitions for a cell.
+        orientation : int
+            Orientation of the agent inside the cell.
+        new_transitions : tuple
+            Tuple of new transitions validitiy for the cell.
+
+        Returns
+        -------
+        int
+            An updated bitmap that replaces the original transitions validity
+            of `cell_transition' with `new_transitions', for the appropriate
+            `orientation'.
+
+        """
+        mask = (1 << ((8-orientation)*8)) - (1 << ((7-orientation)*8))
+        negmask = ~mask
+
+        new_transitions = \
+            (new_transitions[0] & 1) << 7 | \
+            (new_transitions[1] & 1) << 6 | \
+            (new_transitions[2] & 1) << 5 | \
+            (new_transitions[3] & 1) << 4 | \
+            (new_transitions[4] & 1) << 3 | \
+            (new_transitions[5] & 1) << 2 | \
+            (new_transitions[6] & 1) << 1 | \
+            (new_transitions[7] & 1)
+
+        cell_transition = (cell_transition & negmask) | (
+            new_transitions << ((7-orientation)*8))
+
+        return cell_transition
+
+    def get_transition(self, cell_transition, orientation, direction):
+        """
+        Get the transition bit (1 value) that determines whether an agent
+        oriented in direction `orientation' and inside a cell with transitions
+        `cell_transition' can move to the cell in direction `direction'
+        relative to the current cell.
+
+        Parameters
+        ----------
+        cell_transition : int
+            64 bits used to encode the valid transitions for a cell.
+        orientation : int
+            Orientation of the agent inside the cell.
+        direction : int
+            Direction of movement whose validity is to be tested.
+
+        Returns
+        -------
+        int
+            Validity of the requested transition: 0/1 allowed/not allowed.
+
+        """
+        return ((cell_transition >> ((8-1-orientation) * 8)) >>
+                  (8-1-direction)) & 1
+
+    def set_transition(self, cell_transition, orientation, direction,
+                       new_transition):
+        """
+        Set the transition bit (1 value) that determines whether an agent
+        oriented in direction `orientation' and inside a cell with transitions
+        `cell_transition' can move to the cell in direction `direction'
+        relative to the current cell.
+
+        Parameters
+        ----------
+        cell_transition : int
+            64 bits used to encode the valid transitions for a cell.
+        orientation : int
+            Orientation of the agent inside the cell.
+        direction : int
+            Direction of movement whose validity is to be tested.
+        new_transition : int
+            Validity of the requested transition: 0/1 allowed/not allowed.
+
+        Returns
+        -------
+        int
+            An updated bitmap that replaces the original transitions validity
+            of `cell_transition' with `new_transitions', for the appropriate
+            `orientation'.
+
+        """
+        if new_transition:
+            cell_transition |= (1 << ((8-1-orientation) * 8 +
+                                (8 - 1 - direction)))
         else:
-            raise NotImplementedError()
+            cell_transition &= ~(1 << ((8-1-orientation) * 8 +
+                                 (8 - 1 - direction)))
+
+        return cell_transition
+
+    def rotate_transition(self, cell_transition, rotation=0):
+        """
+        Clockwise-rotate a 64-bit transition bitmap by 
+        rotation={0, 45, 90, 135, 180, 225, 270, 315} degrees.
+
+        Parameters
+        ----------
+        cell_transition : int
+            64 bits used to encode the valid transitions for a cell.
+        rotation : int
+            Angle by which to clock-wise rotate the transition bits in
+            `cell_transition' by. I.e., rotation={0, 45, 90, 135, 180,
+            225, 270, 315} degrees.
+
+        Returns
+        -------
+        int
+            An updated bitmap that replaces the original transitions bits
+            with the equivalent bitmap after rotation.
+
+        """
+        # TODO: WARNING: this part of the function has never been tested!
+
+        # Rotate the individual bits in each block
+        value = cell_transition
+        rotation = rotation // 45
+        for i in range(8):
+            block_tuple = self.get_transitions(value, i)
+            block_tuple = block_tuple[rotation:] + block_tuple[:rotation]
+            value = self.set_transitions(value, i, block_tuple)
+
+        # Rotate the 8bits blocks
+        value = ((value & (2**(rotation*8)-1)) <<
+                    ((8-rotation)*8)) | (value >> (rotation*8))
+
+        cell_transition = value
 
         return cell_transition
 
 
-class RailEnvTransitions(GridTransitions):
+class RailEnvTransitions(Grid4Transitions):
     """
     Special case of `GridTransitions' over a 2D-grid, with a pre-defined set
     of transitions mimicking the types of real Swiss rail connections.
@@ -450,6 +547,5 @@ class RailEnvTransitions(GridTransitions):
 
     def __init__(self):
         super(RailEnvTransitions, self).__init__(
-            transitions=self.transition_list,
-            allow_diagonal_transitions=False
+            transitions=self.transition_list
         )
diff --git a/flatland/utils/rail_env_generator.py b/flatland/utils/rail_env_generator.py
index df180e7..4d73d52 100644
--- a/flatland/utils/rail_env_generator.py
+++ b/flatland/utils/rail_env_generator.py
@@ -2,8 +2,8 @@
 The rail_env_generator module defines provides utilities to generate env
 bitmaps for the RailEnv environment.
 """
-import numpy as np
 import random
+import numpy as np
 
 from flatland.core.transitions import RailEnvTransitions
 
@@ -82,8 +82,7 @@ def generate_random_rail(width, height):
     for i in range(len(t_utils.transitions)-1):  # don't include dead-ends
         all_transitions = 0
         for dir_ in range(4):
-            trans = t_utils.get_transitions_from_orientation(
-                     t_utils.transitions[i], dir_)
+            trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
             all_transitions |= (trans[0] << 3) | \
                                (trans[1] << 2) | \
                                (trans[2] << 1) | \
@@ -148,8 +147,7 @@ def generate_random_rail(width, height):
                     max_bit = 0
                     for k in range(4):
                         max_bit |= \
-                         t_utils.get_transition_from_orientation_to_direction(
-                          neigh_trans, k, el[1])
+                         t_utils.get_transition(neigh_trans, k, el[1])
 
                     if max_bit:
                         valid_template[el[0]] = 1
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 1a897d8..af5f68f 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -58,7 +58,7 @@ class RenderTool(object):
                     # transition for next cell
                     oTrans = self.env.rail[rcNext[0]][rcNext[1]]
                     tbTrans = RailEnvTransitions. \
-                        get_transitions_from_orientation(oTrans, iDir)
+                        get_transitions(oTrans, iDir)
                     giTrans = np.where(tbTrans)[0]  # RC list of transitions
                     gTransRCAg = self.__class__.gTransRC[giTrans]
 
@@ -106,7 +106,7 @@ class RenderTool(object):
 
         # TODO: suggest we provide an accessor in RailEnv
         oTrans = self.env.rail[rcPos]   # transition for current cell
-        tbTrans = rt.RETrans.get_transitions_from_orientation(oTrans, iDir)
+        tbTrans = rt.RETrans.get_transitions(oTrans, iDir)
         giTrans = np.where(tbTrans)[0]  # RC list of transitions
 
         # HACK: workaround dead-end transitions
@@ -363,8 +363,7 @@ class RenderTool(object):
                     # renderer.translate(c * CELL_PIXELS, r * CELL_PIXELS)
 
                     if True:
-                        tMoves = RETrans.get_transitions_from_orientation(
-                                    oCell, orientation)
+                        tMoves = RETrans.get_transitions(oCell, orientation)
 
                         # to_ori = (orientation + 2) % 4
                         for to_ori in range(4):
diff --git a/tests/test_environments.py b/tests/test_environments.py
index d5e7dd9..32f8784 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -2,24 +2,22 @@
 # -*- coding: utf-8 -*-
 
 from flatland.core.env import RailEnv
-from flatland.core.transitions import GridTransitions
+from flatland.core.transitions import Grid4Transitions
 import numpy as np
-import random
 
 """Tests for `flatland` package."""
 
 
-
 def test_rail_environment_single_agent():
 
     cells = [int('0000000000000000', 2),  # empty cell - Case 0
-                       int('1000000000100000', 2),  # Case 1 - straight
-                       int('1001001000100000', 2),  # Case 2 - simple switch
-                       int('1000010000100001', 2),  # Case 3 - diamond drossing
-                       int('1001011000100001', 2),  # Case 4 - single slip switch
-                       int('1100110000110011', 2),  # Case 5 - double slip switch
-                       int('0101001000000010', 2),  # Case 6 - symmetrical switch
-                       int('0010000000000000', 2)]  # Case 7 - dead end
+             int('1000000000100000', 2),  # Case 1 - straight
+             int('1001001000100000', 2),  # Case 2 - simple switch
+             int('1000010000100001', 2),  # Case 3 - diamond drossing
+             int('1001011000100001', 2),  # Case 4 - single slip switch
+             int('1100110000110011', 2),  # Case 5 - double slip switch
+             int('0101001000000010', 2),  # Case 6 - symmetrical switch
+             int('0010000000000000', 2)]  # Case 7 - dead end
 
     # We instantiate the following map on a 3x3 grid
     #  _  _
@@ -27,7 +25,7 @@ def test_rail_environment_single_agent():
     # | |  |
     # \_/\_/
 
-    transitions = GridTransitions([], False)
+    transitions = Grid4Transitions([])
     vertical_line = cells[1]
     south_symmetrical_switch = cells[6]
     north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
@@ -51,7 +49,7 @@ def test_rail_environment_single_agent():
 
         # Check that trains are always initialized at a consistent position / direction.
         # They should always be able to go somewhere.
-        assert(transitions.get_transitions_from_orientation(
+        assert(transitions.get_transitions(
             rail_map[rail_env.agents_position[0]],
             rail_env.agents_direction[0]) != (0, 0, 0, 0))
 
diff --git a/tests/test_transitions.py b/tests/test_transitions.py
index 2c59add..f68b836 100644
--- a/tests/test_transitions.py
+++ b/tests/test_transitions.py
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 
 """Tests for `flatland` package."""
-from flatland.core.transitions import RailEnvTransitions, GridTransitions
+from flatland.core.transitions import RailEnvTransitions, Grid8Transitions
 
 
 def test_valid_railenv_transitions():
@@ -14,36 +14,36 @@ def test_valid_railenv_transitions():
     #            'W': 3}
 
     for i in range(2):
-        assert(rail_env_trans.get_transitions_from_orientation(
+        assert(rail_env_trans.get_transitions(
                     int('1100110000110011', 2), i) == (1, 1, 0, 0))
-        assert(rail_env_trans.get_transitions_from_orientation(
+        assert(rail_env_trans.get_transitions(
                     int('1100110000110011', 2), 2+i) == (0, 0, 1, 1))
 
     no_transition_cell = int('0000000000000000', 2)
 
     for i in range(4):
-        assert(rail_env_trans.get_transitions_from_orientation(
+        assert(rail_env_trans.get_transitions(
                     no_transition_cell, i) == (0, 0, 0, 0))
 
     # Facing south, going south
-    north_south_transition = rail_env_trans.set_transitions_from_orientation(
+    north_south_transition = rail_env_trans.set_transitions(
                     no_transition_cell, 2, (0, 0, 1, 0))
-    assert(rail_env_trans.set_transition_from_orientation_to_direction(
+    assert(rail_env_trans.set_transition(
                     north_south_transition, 2, 2, 0) == no_transition_cell)
-    assert(rail_env_trans.get_transition_from_orientation_to_direction(
+    assert(rail_env_trans.get_transition(
                     north_south_transition, 2, 2))
 
     # Facing north, going east
     south_east_transition = \
-        rail_env_trans.set_transition_from_orientation_to_direction(
+        rail_env_trans.set_transition(
          no_transition_cell, 0, 1, 1)
-    assert(rail_env_trans.get_transition_from_orientation_to_direction(
+    assert(rail_env_trans.get_transition(
             south_east_transition, 0, 1))
 
     # The opposite transitions are not feasible
-    assert(not rail_env_trans.get_transition_from_orientation_to_direction(
+    assert(not rail_env_trans.get_transition(
             north_south_transition, 2, 0))
-    assert(not rail_env_trans.get_transition_from_orientation_to_direction(
+    assert(not rail_env_trans.get_transition(
             south_east_transition, 2, 1))
 
     east_west_transition = rail_env_trans.rotate_transition(
@@ -52,10 +52,10 @@ def test_valid_railenv_transitions():
             south_east_transition, 180)
 
     # Facing west, going west
-    assert(rail_env_trans.get_transition_from_orientation_to_direction(
+    assert(rail_env_trans.get_transition(
             east_west_transition, 3, 3))
     # Facing south, going west
-    assert(rail_env_trans.get_transition_from_orientation_to_direction(
+    assert(rail_env_trans.get_transition(
             north_west_transition, 2, 3))
 
     assert(south_east_transition == rail_env_trans.rotate_transition(
@@ -63,16 +63,16 @@ def test_valid_railenv_transitions():
 
 
 def test_diagonal_transitions():
-    diagonal_trans_env = GridTransitions([], True)
+    diagonal_trans_env = Grid8Transitions([])
 
     # Facing north, going north-east
     south_northeast_transition = int('01000000' + '0'*8*7, 2)
-    assert(diagonal_trans_env.get_transitions_from_orientation(
+    assert(diagonal_trans_env.get_transitions(
             south_northeast_transition, 0) == (0, 1, 0, 0, 0, 0, 0, 0))
 
     # Allowing transition from north to southwest: Facing south, going SW
     north_southwest_transition = \
-        diagonal_trans_env.set_transitions_from_orientation(
+        diagonal_trans_env.set_transitions(
          int('0' * 64, 2), 4, (0, 0, 0, 0, 0, 1, 0, 0))
 
     assert(diagonal_trans_env.rotate_transition(
-- 
GitLab