From e1e0947e1299abac425c01bab3721fa128478324 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Fri, 5 Jul 2019 21:48:34 +0200
Subject: [PATCH] refactoring transitions_map

---
 flatland/core/grid/grid4.py                |  3 +-
 flatland/core/transition_map.py            | 36 +++++++++++++++-------
 flatland/envs/grid4_generators_utils.py    |  8 ++---
 flatland/envs/observations.py              | 16 +++++-----
 flatland/envs/predictions.py               |  2 +-
 flatland/envs/rail_env.py                  |  4 +--
 flatland/utils/editor.py                   |  6 ++--
 flatland/utils/rendertools.py              |  8 ++---
 tests/test_flatland_core_transition_map.py | 35 +++++++++++++++++----
 9 files changed, 78 insertions(+), 40 deletions(-)

diff --git a/flatland/core/grid/grid4.py b/flatland/core/grid/grid4.py
index 714123ed..b4b5b17c 100644
--- a/flatland/core/grid/grid4.py
+++ b/flatland/core/grid/grid4.py
@@ -1,4 +1,5 @@
 from enum import IntEnum
+from typing import Type
 
 import numpy as np
 
@@ -218,7 +219,7 @@ class Grid4Transitions(Transitions):
         cell_transition = value
         return cell_transition
 
-    def get_direction_enum(self) -> IntEnum:
+    def get_direction_enum(self) -> Type[Grid4TransitionsEnum]:
         return Grid4TransitionsEnum
 
     def has_deadend(self, cell_transition):
diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index cb09a628..5e0f6cd7 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -7,6 +7,7 @@ from importlib_resources import path
 from numpy import array
 
 from flatland.core.grid.grid4 import Grid4Transitions
+from flatland.core.transitions import Transitions
 
 
 class TransitionMap:
@@ -110,7 +111,7 @@ class GridTransitionMap(TransitionMap):
     GridTransitionMap implements utility functions.
     """
 
-    def __init__(self, width, height, transitions=Grid4Transitions([])):
+    def __init__(self, width, height, transitions: Transitions = Grid4Transitions([])):
         """
         Builder for GridTransitionMap object.
 
@@ -132,7 +133,25 @@ class GridTransitionMap(TransitionMap):
 
         self.grid = np.zeros((height, width), dtype=self.transitions.get_type())
 
-    def get_transitions(self, cell_id):
+    def get_full_transitions(self, row, column):
+        """
+        Returns the full transitions for the cell at (row, column) in the format transition_map's transitions.
+
+        Parameters
+        ----------
+        row: int
+        column: int
+            (row,column) specifies the cell in this transition map.
+
+        Returns
+        -------
+        self.transitions.get_type()
+            The cell content int the format of this map's Transitions.
+
+        """
+        return self.grid[row][column]
+
+    def get_transitions(self, row, column, orientation):
         """
         Return a tuple of transitions available in a cell specified by
         `cell_id' (e.g., a tuple of size of the maximum number of transitions,
@@ -150,15 +169,10 @@ class GridTransitionMap(TransitionMap):
         Returns
         -------
         tuple
-            List of the validity of transitions in the cell.
+            List of the validity of transitions in the cell as given by the maps transitions.
 
         """
-        assert len(cell_id) in (2, 3), \
-            'GridTransitionMap.get_transitions() ERROR: cell_id tuple must have length 2 or 3.'
-        if len(cell_id) == 3:
-            return self.transitions.get_transitions(self.grid[cell_id[0]][cell_id[1]], cell_id[2])
-        elif len(cell_id) == 2:
-            return self.grid[cell_id[0]][cell_id[1]]
+        return self.transitions.get_transitions(self.grid[row][column], orientation)
 
     def set_transitions(self, cell_id, new_transitions):
         """
@@ -308,7 +322,7 @@ class GridTransitionMap(TransitionMap):
         grcPos = array(rcPos)
         grcMax = self.grid.shape
 
-        binTrans = self.get_transitions(rcPos)  # 16bit integer - all trans in/out
+        binTrans = self.get_full_transitions(*rcPos)  # 16bit integer - all trans in/out
         lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8)  # 2 x uint8
         g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4)  # 4x4 x uint8 binary(0,1)
         gDirOut = g2binTrans.any(axis=0)  # outbound directions as boolean array (4)
@@ -328,7 +342,7 @@ class GridTransitionMap(TransitionMap):
 
             # Get the transitions out of gPos2, using iDirOut as the inbound direction
             # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
-            t4Trans2 = self.get_transitions((*gPos2, iDirOut))
+            t4Trans2 = self.get_transitions(*gPos2, iDirOut)
             if any(t4Trans2):
                 continue
             else:
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index 4b2ab8cb..dedd76b6 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -75,7 +75,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
                 return 1
             if node not in visited:
                 visited.add(node)
-                moves = rail.get_transitions((node[0][0], node[0][1], node[1]))
+                moves = rail.get_transitions(node[0][0], node[0][1], node[1])
                 for move_index in range(4):
                     if moves[move_index]:
                         stack.append((get_new_position(node[0], move_index),
@@ -84,7 +84,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
                 # If cell is a dead-end, append previous node with reversed
                 # orientation!
                 nbits = 0
-                tmp = rail.get_transitions((node[0][0], node[0][1]))
+                tmp = rail.get_full_transitions(node[0][0], node[0][1])
                 while tmp > 0:
                     nbits += (tmp & 1)
                     tmp = tmp >> 1
@@ -96,7 +96,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
     valid_positions = []
     for r in range(rail.height):
         for c in range(rail.width):
-            if rail.get_transitions((r, c)) > 0:
+            if rail.get_full_transitions(r, c) > 0:
                 valid_positions.append((r, c))
 
     re_generate = True
@@ -116,7 +116,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
             valid_movements = []
             for direction in range(4):
                 position = agents_position[i]
-                moves = rail.get_transitions((position[0], position[1], direction))
+                moves = rail.get_transitions(position[0], position[1], direction)
                 for move_index in range(4):
                     if moves[move_index]:
                         valid_movements.append((direction, move_index))
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 1f02d518..add983c0 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -253,7 +253,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         if handle > len(self.env.agents):
             print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
         agent = self.env.agents[handle]  # TODO: handle being treated as index
-        possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
+        possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
         num_transitions = np.count_nonzero(possible_transitions)
 
         # Root node - current position
@@ -383,8 +383,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                 last_is_target = True
                 break
 
-            cell_transitions = self.env.rail.get_transitions((*position, direction))
-            total_transitions = bin(self.env.rail.get_transitions(position)).count("1")
+            cell_transitions = self.env.rail.get_transitions(*position, direction)
+            total_transitions = bin(self.env.rail.get_full_transitions(*position)).count("1")
             num_transitions = np.count_nonzero(cell_transitions)
             exploring = False
             # Detect Switches that can only be used by other agents.
@@ -394,7 +394,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             if num_transitions == 1:
                 # Check if dead-end, or if we can go forward along direction
                 nbits = 0
-                tmp = self.env.rail.get_transitions(tuple(position))
+                tmp = self.env.rail.get_full_transitions(*position)
                 while tmp > 0:
                     nbits += (tmp & 1)
                     tmp = tmp >> 1
@@ -469,7 +469,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         # Get the possible transitions
-        possible_transitions = self.env.rail.get_transitions((*position, direction))
+        possible_transitions = self.env.rail.get_transitions(*position, direction)
         for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
             if last_is_dead_end and self.env.rail.get_transition((*position, direction),
                                                                  (branch_direction + 2) % 4):
@@ -572,7 +572,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
         self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
         for i in range(self.rail_obs.shape[0]):
             for j in range(self.rail_obs.shape[1]):
-                bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
+                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
                 bitlist = [0] * (16 - len(bitlist)) + bitlist
                 self.rail_obs[i, j] = np.array(bitlist)
 
@@ -630,7 +630,7 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
         self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
         for i in range(self.rail_obs.shape[0]):
             for j in range(self.rail_obs.shape[1]):
-                bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
+                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
                 bitlist = [0] * (16 - len(bitlist)) + bitlist
                 self.rail_obs[i, j] = np.array(bitlist)
 
@@ -701,7 +701,7 @@ class LocalObsForRailEnv(ObservationBuilder):
                                   self.env.width + 2 * self.view_radius, 16))
         for i in range(self.env.height):
             for j in range(self.env.width):
-                bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
+                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
                 bitlist = [0] * (16 - len(bitlist)) + bitlist
                 self.rail_obs[i + self.view_radius, j + self.view_radius] = np.array(bitlist)
 
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index d471596b..2605e84c 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -131,7 +131,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                     prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
                     continue
                 # Take shortest possible path
-                cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
+                cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
 
                 new_position = None
                 new_direction = None
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index c7050550..4e8832ec 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -322,7 +322,7 @@ class RailEnv(Environment):
                 new_position,
                 np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
             and  # check the new position has some transitions (ie is not an empty cell)
-            self.rail.get_transitions(new_position) > 0)
+            self.rail.get_full_transitions(*new_position) > 0)
 
         # If transition validity hasn't been checked yet.
         if transition_isValid is None:
@@ -338,7 +338,7 @@ class RailEnv(Environment):
 
     def check_action(self, agent, action):
         transition_isValid = None
-        possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
+        possible_transitions = self.rail.get_transitions(*agent.position, agent.direction)
         num_transitions = np.count_nonzero(possible_transitions)
 
         new_direction = agent.direction
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index dd0c4d4b..ea4056fa 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -494,7 +494,7 @@ class EditorModel(object):
 
         if len(lrcStroke) >= 2:
             # If the first cell in a stroke is empty, add a deadend to cell 0
-            if self.env.rail.get_transitions(lrcStroke[0]) == 0:
+            if self.env.rail.get_full_transitions(*lrcStroke[0]) == 0:
                 self.mod_rail_2cells(lrcStroke, bAddRemove, iCellToMod=0)
 
         # Add transitions for groups of 3 cells
@@ -504,7 +504,7 @@ class EditorModel(object):
 
         # If final cell empty, insert deadend:
         if len(lrcStroke) == 2:
-            if self.env.rail.get_transitions(lrcStroke[1]) == 0:
+            if self.env.rail.get_full_transitions(*lrcStroke[1]) == 0:
                 self.mod_rail_2cells(lrcStroke, bAddRemove, iCellToMod=1)
 
         # now empty out the final two cells from the queue
@@ -752,7 +752,7 @@ class EditorModel(object):
             self.log(*args, **kwargs)
 
     def debug_cell(self, rcCell):
-        binTrans = self.env.rail.get_transitions(rcCell)
+        binTrans = self.env.rail.get_full_transitions(*rcCell)
         sbinTrans = format(binTrans, "#018b")[2:]
         self.debug("cell ",
                    rcCell,
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index a148a5ae..6fa68f02 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -86,7 +86,7 @@ class RenderTool(object):
 
         for visit in lVisits:
             # transition for next cell
-            tbTrans = self.env.rail.get_transitions((*visit.rc, visit.iDir))
+            tbTrans = self.env.rail.get_transitions(*visit.rc, visit.iDir)
             giTrans = np.where(tbTrans)[0]  # RC list of transitions
             gTransRCAg = rt.gTransRC[giTrans]
             self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
@@ -125,7 +125,7 @@ class RenderTool(object):
         )
         """
 
-        tbTrans = self.env.rail.get_transitions((*rcPos, iDir))
+        tbTrans = self.env.rail.get_transitions(*rcPos, iDir)
         giTrans = np.where(tbTrans)[0]  # RC list of transitions
 
         # HACK: workaround dead-end transitions
@@ -459,7 +459,7 @@ class RenderTool(object):
                 xyCentre = array([x0, y1]) + cell_size / 2
 
                 # cell transition values
-                oCell = env.rail.get_transitions((r, c))
+                oCell = env.rail.get_full_transitions(r, c)
 
                 bCellValid = env.rail.cell_neighbours_valid((r, c), check_this_cell=True)
 
@@ -482,7 +482,7 @@ class RenderTool(object):
                     from_ori = (orientation + 2) % 4  # 0123=NESW -> 2301=SWNE
                     from_xy = coords[from_ori]
 
-                    tMoves = env.rail.get_transitions((r, c, orientation))
+                    tMoves = env.rail.get_transitions(r, c, orientation)
 
                     for to_ori in range(4):
                         to_xy = coords[to_ori]
diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py
index 8013e912..a4142316 100644
--- a/tests/test_flatland_core_transition_map.py
+++ b/tests/test_flatland_core_transition_map.py
@@ -5,19 +5,42 @@ from flatland.core.transition_map import GridTransitionMap
 
 def test_grid4_get_transitions():
     grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))
-    assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0)
+    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (0, 0, 0, 0)
+    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
+    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
+    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
+    assert grid4_map.get_full_transitions(0, 0) == 0
+
     grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 1)
-    assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (1, 0, 0, 0)
+    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (1, 0, 0, 0)
+    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
+    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
+    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
+    assert grid4_map.get_full_transitions(0, 0) == pow(2, 15)  # the most significant bit is on
+
+    grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.WEST, 1)
+    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (1, 0, 0, 1)
+    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
+    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
+    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
+    # the most significant and the fourth most significant bits are on
+    assert grid4_map.get_full_transitions(0, 0) == pow(2, 15) + pow(2, 12)
+
     grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 0)
-    assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0)
+    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (0, 0, 0, 1)
+    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
+    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
+    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
+    # the fourth most significant bits are on
+    assert grid4_map.get_full_transitions(0, 0) == pow(2, 12)
 
 
 def test_grid8_set_transitions():
     grid8_map = GridTransitionMap(2, 2, Grid8Transitions([]))
-    assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0)
+    assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0)
     grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 1)
-    assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (1, 0, 0, 0, 0, 0, 0, 0)
+    assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (1, 0, 0, 0, 0, 0, 0, 0)
     grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0)
-    assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0)
+    assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0)
 
 # TODO GridTransitionMap
-- 
GitLab