From 25bb98a0c5e908751d9bee21e48fd3f5380f8db7 Mon Sep 17 00:00:00 2001
From: spiglerg <spiglerg@gmail.com>
Date: Mon, 15 Apr 2019 14:06:15 +0200
Subject: [PATCH] Added TransitionMap objects + replaced relevant use of grid
 cell

---
 examples/temporary_example.py        |   6 +-
 flatland/core/env.py                 |  40 ++---
 flatland/core/transitionmap.py       | 251 +++++++++++++++++++++++++++
 flatland/utils/rail_env_generator.py |  12 +-
 flatland/utils/rendertools.py        |  19 +-
 tests/test_rendertools.py            |  41 ++---
 6 files changed, 298 insertions(+), 71 deletions(-)
 create mode 100644 flatland/core/transitionmap.py

diff --git a/examples/temporary_example.py b/examples/temporary_example.py
index cd6d42de..2ea68cfd 100644
--- a/examples/temporary_example.py
+++ b/examples/temporary_example.py
@@ -17,7 +17,7 @@ env = RailEnv(rail, number_of_agents=10)
 env.reset()
 
 env_renderer = RenderTool(env)
-env_renderer.renderEnv()
+env_renderer.renderEnv(show=True)
 
 
 # Example generate a rail given a manual specification,
@@ -37,7 +37,7 @@ env.agents_target = [[1, 1]]
 env.agents_direction = [1]
 
 env_renderer = RenderTool(env)
-env_renderer.renderEnv()
+env_renderer.renderEnv(show=True)
 
 
 print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \
@@ -64,4 +64,4 @@ for step in range(100):
             i = i+1
         i += 1
 
-    env_renderer.renderEnv()
+    env_renderer.renderEnv(show=True)
diff --git a/flatland/core/env.py b/flatland/core/env.py
index 7cc3b768..d6493507 100644
--- a/flatland/core/env.py
+++ b/flatland/core/env.py
@@ -5,8 +5,6 @@ The base Environment class is adapted from rllib.env.MultiAgentEnv
 """
 import random
 
-from .transitions import RailEnvTransitions
-
 
 class Environment:
     """
@@ -133,8 +131,8 @@ class RailEnv:
         """
 
         self.rail = rail
-        self.width = len(self.rail[0])
-        self.height = len(self.rail)
+        self.width = rail.width
+        self.height = rail.height
 
         self.number_of_agents = number_of_agents
 
@@ -144,8 +142,6 @@ class RailEnv:
 
         self.agents_handles = list(range(self.number_of_agents))
 
-        self.trans = RailEnvTransitions()
-
     def get_agent_handles(self):
         return self.agents_handles
 
@@ -159,7 +155,7 @@ class RailEnv:
             valid_positions = []
             for r in range(self.height):
                 for c in range(self.width):
-                    if self.rail[r][c] > 0:
+                    if self.rail.get_transitions((r, c)) > 0:
                         valid_positions.append((r, c))
 
             self.agents_position = random.sample(valid_positions,
@@ -175,8 +171,8 @@ class RailEnv:
                 valid_movements = []
                 for direction in range(4):
                     position = self.agents_position[i]
-                    moves = self.trans.get_transitions(
-                             self.rail[position[0]][position[1]], direction)
+                    moves = self.rail.get_transitions(
+                            (position[0], position[1], direction))
                     for move_index in range(4):
                         if moves[move_index]:
                             valid_movements.append((direction, move_index))
@@ -251,8 +247,9 @@ class RailEnv:
                 if action == 2:
                     # compute number of possible transitions in the current
                     # cell
+                    is_deadend = False
                     nbits = 0
-                    tmp = self.rail[pos[0]][pos[1]]
+                    tmp = self.rail.get_transitions((pos[0], pos[1]))
                     while tmp > 0:
                         nbits += (tmp & 1)
                         tmp = tmp >> 1
@@ -270,14 +267,13 @@ class RailEnv:
                         elif direction == 3:
                             reverse_direction = 1
 
-                        valid_transition = self.trans.get_transition(
-                                            self.rail[pos[0]][pos[1]],
-                                            reverse_direction,
+                        valid_transition = self.rail.get_transition(
+                                            (pos[0], pos[1], direction),
                                             reverse_direction)
-
                         if valid_transition:
                             direction = reverse_direction
-                            movement = direction
+                            movement = reverse_direction
+                            is_deadend = True
 
                 new_position = self._new_position(pos, movement)
 
@@ -289,15 +285,14 @@ class RailEnv:
                    new_position[0] < 0 or new_position[1] < 0:
                     new_cell_isValid = False
 
-                elif self.rail[new_position[0]][new_position[1]] > 0:
+                elif self.rail.get_transitions((new_position[0], new_position[1])) > 0:
                     new_cell_isValid = True
                 else:
                     new_cell_isValid = False
 
-                transition_isValid = self.trans.get_transition(
-                     self.rail[pos[0]][pos[1]],
-                     direction,
-                     movement)
+                transition_isValid = self.rail.get_transition(
+                     (pos[0], pos[1], direction),
+                     movement) or is_deadend
 
                 cell_isFree = True
                 for j in range(self.number_of_agents):
@@ -363,8 +358,7 @@ class RailEnv:
                 return 1
             if node not in visited:
                 visited.add(node)
-                moves = self.trans.get_transitions(
-                         self.rail[node[0][0]][node[0][1]], node[1])
+                moves = self.rail.get_transitions((node[0][0], node[0][1], node[1]))
                 for move_index in range(4):
                     if moves[move_index]:
                         stack.append((self._new_position(node[0], move_index),
@@ -373,7 +367,7 @@ class RailEnv:
                 # If cell is a dead-end, append previous node with reversed
                 # orientation!
                 nbits = 0
-                tmp = self.rail[node[0][0]][node[0][1]]
+                tmp = self.rail.get_transitions((node[0][0], node[0][1]))
                 while tmp > 0:
                     nbits += (tmp & 1)
                     tmp = tmp >> 1
diff --git a/flatland/core/transitionmap.py b/flatland/core/transitionmap.py
new file mode 100644
index 00000000..d3fcf5c8
--- /dev/null
+++ b/flatland/core/transitionmap.py
@@ -0,0 +1,251 @@
+"""
+TransitionMap and derived classes.
+"""
+
+import numpy as np
+
+from .transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions
+
+
+class TransitionMap:
+    """
+    Base TransitionMap class.
+
+    Generic class that declares the interface of a collection of transitions
+    over a set of cells. Every method raises NotImplementedError here.
+    """
+
+    def get_transitions(self, cell_id):
+        """
+        Return a tuple of transitions available in a cell specified by
+        `cell_id' (e.g., a tuple of size of the maximum number of transitions,
+        with values 0 or 1, or potentially in between,
+        for stochastic transitions).
+
+        Parameters
+        ----------
+        cell_id : [cell identifier]
+            The cell_id object depends on the specific implementation.
+            It generally is an int (e.g., an index) or a tuple of indices.
+
+        Returns
+        -------
+        tuple
+            List of the validity of transitions in the cell.
+
+        """
+        raise NotImplementedError()
+
+    def set_transitions(self, cell_id, new_transitions):
+        """
+        Replaces the available transitions in cell `cell_id' with the tuple
+        `new_transitions'. `new_transitions' must have
+        one element for each possible transition.
+
+        Parameters
+        ----------
+        cell_id : [cell identifier]
+            The cell_id object depends on the specific implementation.
+            It generally is an int (e.g., an index) or a tuple of indices.
+        new_transitions : tuple
+            Tuple of new transitions validity for the cell.
+
+        """
+        raise NotImplementedError()
+
+    def get_transition(self, cell_id, transition_index):
+        """
+        Return the status of whether an agent in cell `cell_id' can perform a
+        movement along transition `transition_index' (e.g., the NESW direction
+        of movement, for agents on a grid).
+
+        Parameters
+        ----------
+        cell_id : [cell identifier]
+            The cell_id object depends on the specific implementation.
+            It generally is an int (e.g., an index) or a tuple of indices.
+        transition_index : int
+            Index of the transition to probe, as index in the tuple returned by
+            get_transitions(). e.g., the NESW direction of movement, for agents
+            on a grid.
+
+        Returns
+        -------
+        int or float (depending on derived class)
+            Validity of the requested transition (e.g.,
+            0/1 not allowed/allowed, a probability in [0,1], etc...)
+
+        """
+        raise NotImplementedError()
+
+    def set_transition(self, cell_id, transition_index, new_transition):
+        """
+        Replaces the validity of transition to `transition_index' in cell
+        `cell_id' with the new `new_transition'.
+
+
+        Parameters
+        ----------
+        cell_id : [cell identifier]
+            The cell_id object depends on the specific implementation.
+            It generally is an int (e.g., an index) or a tuple of indices.
+        transition_index : int
+            Index of the transition to set, as index in the tuple returned by
+            get_transitions(). e.g., the NESW direction of movement, for agents
+            on a grid.
+        new_transition : int or float (depending on derived class)
+            Validity of the requested transition (e.g.,
+            0/1 not allowed/allowed, a probability in [0,1], etc...)
+
+        """
+        raise NotImplementedError()
+
+
+class GridTransitionMap(TransitionMap):
+    """
+    Implements a TransitionMap over a 2D grid.
+
+    Cell contents are stored in `self.grid', an array of shape (height, width).
+    """
+
+    def __init__(self, width, height, transitions=Grid4Transitions([])):
+        """
+        Builder for GridTransitionMap object.
+
+        Parameters
+        ----------
+        width : int
+            Width of the grid.
+        height : int
+            Height of the grid.
+        transitions : Transitions object
+            The Transitions object to use to encode/decode transitions over the
+            grid.
+
+        """
+        self.width = width
+        self.height = height
+        self.transitions = transitions
+
+        # Zero-fill: np.ndarray would leave the cells holding arbitrary memory.
+        if isinstance(self.transitions, (Grid4Transitions, RailEnvTransitions)):
+            self.grid = np.zeros((height, width), dtype=np.uint16)
+        elif isinstance(self.transitions, Grid8Transitions):
+            self.grid = np.zeros((height, width), dtype=np.uint64)
+
+    def get_transitions(self, cell_id):
+        """
+        Return a tuple of transitions available in a cell specified by
+        `cell_id' (e.g., a tuple of size of the maximum number of transitions,
+        with values 0 or 1, or potentially in between,
+        for stochastic transitions).
+
+        Parameters
+        ----------
+        cell_id : tuple
+            The cell_id indexes a cell as (row, column, orientation),
+            where orientation is the direction an agent is facing within a cell.
+            Alternatively, it can be accessed as (row, column) to return the
+            full cell content.
+
+        Returns
+        -------
+        tuple, or raw cell content for a 2-element `cell_id'
+            Validity of the transitions in the cell; () if `cell_id' has any
+            other length (an error is printed in that case).
+        """
+        if len(cell_id) == 3:
+            return self.transitions.get_transitions(self.grid[cell_id[0]][cell_id[1]], cell_id[2])
+        elif len(cell_id) == 2:
+            return self.grid[cell_id[0]][cell_id[1]]
+        else:
+            print('GridTransitionMap.get_transitions() ERROR: \
+                   wrong cell_id tuple.')
+            return ()
+
+    def set_transitions(self, cell_id, new_transitions):
+        """
+        Replaces the available transitions in cell `cell_id' with the tuple
+        `new_transitions', which must have one element per possible transition.
+
+        Parameters
+        ----------
+        cell_id : tuple
+            The cell_id indexes a cell as (row, column, orientation),
+            where orientation is the direction an agent is facing within a cell.
+            Alternatively, it can be accessed as (row, column) to replace the
+            full cell content.
+        new_transitions : tuple
+            Tuple of new transitions validity for the cell.
+        """
+        if len(cell_id) == 3:
+            # Cell values are immutable ints: store what the helper returns.
+            self.grid[cell_id[0]][cell_id[1]] = self.transitions.set_transitions(
+                self.grid[cell_id[0]][cell_id[1]], cell_id[2], new_transitions)
+        elif len(cell_id) == 2:
+            self.grid[cell_id[0]][cell_id[1]] = new_transitions
+        else:
+            print('GridTransitionMap.set_transitions() ERROR: \
+                   wrong cell_id tuple.')
+
+    def get_transition(self, cell_id, transition_index):
+        """
+        Return the status of whether an agent in cell `cell_id' can perform a
+        movement along transition `transition_index' (e.g., the NESW direction
+        of movement, for agents on a grid).
+
+        Parameters
+        ----------
+        cell_id : tuple
+            The cell_id indexes a cell as (row, column, orientation),
+            where orientation is the direction an agent is facing within a cell.
+        transition_index : int
+            Index of the transition to probe, as index in the tuple returned by
+            get_transitions(). e.g., the NESW direction of movement, for agents
+            on a grid.
+
+        Returns
+        -------
+        int or float (depending on derived class)
+            Validity of the requested transition (e.g.,
+            0/1 not allowed/allowed, a probability in [0,1], etc...), or ()
+            if `cell_id' does not have exactly 3 elements.
+        """
+        if len(cell_id) != 3:
+            print('GridTransitionMap.get_transition() ERROR: \
+                   wrong cell_id tuple.')
+            return ()
+        return self.transitions.get_transition(self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index)
+
+    def set_transition(self, cell_id, transition_index, new_transition):
+        """
+        Replaces the validity of transition to `transition_index' in cell
+        `cell_id' with the new `new_transition'.
+
+        Parameters
+        ----------
+        cell_id : tuple
+            The cell_id indexes a cell as (row, column, orientation),
+            where orientation is the direction an agent is facing within a cell.
+        transition_index : int
+            Index of the transition to set, as index in the tuple returned by
+            get_transitions(). e.g., the NESW direction of movement, for agents
+            on a grid.
+        new_transition : int or float (depending on derived class)
+            Validity of the requested transition (e.g.,
+            0/1 not allowed/allowed, a probability in [0,1], etc...)
+        """
+        if len(cell_id) != 3:
+            print('GridTransitionMap.set_transition() ERROR: \
+                   wrong cell_id tuple.')
+            return
+        # Cell values are immutable ints: store what the helper returns.
+        self.grid[cell_id[0]][cell_id[1]] = self.transitions.set_transition(
+            self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index, new_transition)
+
+
+# TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
+# (most general implementation) or to make Grid-class specific methods for
+# slicing over the 3 dimensions?  I'd say both perhaps.
+
+# TODO: override __getitem__ and __setitem__ (cell contents, not transitions?)
diff --git a/flatland/utils/rail_env_generator.py b/flatland/utils/rail_env_generator.py
index 4d73d520..5b292f03 100644
--- a/flatland/utils/rail_env_generator.py
+++ b/flatland/utils/rail_env_generator.py
@@ -6,6 +6,7 @@ import random
 import numpy as np
 
 from flatland.core.transitions import RailEnvTransitions
+from flatland.core.transitionmap import GridTransitionMap
 
 
 def generate_rail_from_manual_specifications(rail_spec):
@@ -30,7 +31,7 @@ def generate_rail_from_manual_specifications(rail_spec):
 
     height = len(rail_spec)
     width = len(rail_spec[0])
-    rail = np.zeros((height, width), dtype=np.uint16)
+    rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
 
     for r in range(height):
         for c in range(width):
@@ -38,8 +39,8 @@ def generate_rail_from_manual_specifications(rail_spec):
             if cell[0] < 0 or cell[0] >= len(t_utils.transitions):
                 print("ERROR - invalid cell type=", cell[0])
                 return []
-            rail[r, c] = t_utils.rotate_transition(
-                          t_utils.transitions[cell[0]], cell[1])
+            rail.set_transitions((r, c), t_utils.rotate_transition(
+                          t_utils.transitions[cell[0]], cell[1]))
 
     return rail
 
@@ -300,4 +301,7 @@ def generate_random_rail(width, height):
             if rail[r][c] is None:
                 rail[r][c] = int('0000000000000000', 2)
 
-    return np.asarray(rail, dtype=np.uint16)
+    tmp_rail = np.asarray(rail, dtype=np.uint16)
+    return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
+    return_rail.grid = tmp_rail
+    return return_rail
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 0068928e..fc426f80 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -4,7 +4,6 @@ import numpy as np
 from numpy import array
 import xarray as xr
 import matplotlib.pyplot as plt
-from flatland.core.transitions import RailEnvTransitions
 
 
 class RenderTool(object):
@@ -25,7 +24,6 @@ class RenderTool(object):
     gCentres = xr.DataArray(gGrid,
                             dims=["xy", "p1", "p2"],
                             coords={"xy": ["x", "y"]}) + xyPixHalf
-    RETrans = RailEnvTransitions()
 
     def __init__(self, env):
         self.env = env
@@ -56,16 +54,14 @@ class RenderTool(object):
                     # TODO: this was `rcDir' but it was undefined
                     rcNext = rcPos + iDir
                     # transition for next cell
-                    oTrans = self.env.rail[rcNext[0]][rcNext[1]]
-                    tbTrans = RailEnvTransitions. \
-                        get_transitions(oTrans, iDir)
+                    tbTrans = self.env.rail. \
+                        get_transitions((rcNext[0], rcNext[1], iDir))
                     giTrans = np.where(tbTrans)[0]  # RC list of transitions
                     gTransRCAg = self.__class__.gTransRC[giTrans]
 
         for visit in lVisits:
             # transition for next cell
-            oTrans = self.env.rail[visit.rc]
-            tbTrans = rt.RETrans.get_transitions(oTrans, visit.iDir)
+            tbTrans = self.env.rail.get_transitions((visit.rc[0], visit.rc[1], visit.iDir))
             giTrans = np.where(tbTrans)[0]  # RC list of transitions
             gTransRCAg = rt.gTransRC[giTrans]
             self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
@@ -102,11 +98,9 @@ class RenderTool(object):
             [0, 1] #  available transition indices, ie N, E
         )
         """
-        rt = self.__class__
 
         # TODO: suggest we provide an accessor in RailEnv
-        oTrans = self.env.rail[rcPos]   # transition for current cell
-        tbTrans = rt.RETrans.get_transitions(oTrans, iDir)
+        tbTrans = self.env.get_transitions((rcPos[0], rcPos[1], iDir))
         giTrans = np.where(tbTrans)[0]  # RC list of transitions
 
         # HACK: workaround dead-end transitions
@@ -406,7 +400,6 @@ class RenderTool(object):
                         ])
                     plt.plot(*xyArrow.T, color=sColor)
 
-        RETrans = RailEnvTransitions()
         env = self.env
 
         # Draw cells grid
@@ -442,7 +435,7 @@ class RenderTool(object):
                 xyCentre = array([x0, y1]) + cell_size / 2
 
                 # cell transition values
-                oCell = env.rail[r, c]
+                oCell = env.rail.get_transitions((r, c))
 
                 # Special Case 7, with a single bit; terminate at center
                 nbits = 0
@@ -463,7 +456,7 @@ class RenderTool(object):
                     # renderer.push()
                     # renderer.translate(c * CELL_PIXELS, r * CELL_PIXELS)
 
-                    tMoves = RETrans.get_transitions(oCell, orientation)
+                    tMoves = env.rail.get_transitions((r, c, orientation))
 
                     # to_ori = (orientation + 2) % 4
                     for to_ori in range(4):
diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py
index c68f84db..2ae69015 100644
--- a/tests/test_rendertools.py
+++ b/tests/test_rendertools.py
@@ -1,32 +1,20 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+"""
+Tests for `flatland` package.
+"""
 
 from flatland.core.env import RailEnv
-#from flatland.core.transitions import GridTransitions
 import numpy as np
 import random
 import os
 
-from recordtype import recordtype
-
-import numpy as np
-from numpy import array
-import xarray as xr
 import matplotlib.pyplot as plt
 
-from flatland.core.transitions import RailEnvTransitions
-#import flatland.core.env
 from flatland.utils import rail_env_generator
-from flatland.core.env import RailEnv
 import flatland.utils.rendertools as rt
 
 
-
-
-"""Tests for `flatland` package."""
-
-
-
 def checkFrozenImage(sFileImage):
     sDirRoot = "."
     sTmpFileImage = sDirRoot + "/images/test/" + sFileImage
@@ -37,7 +25,7 @@ def checkFrozenImage(sFileImage):
     plt.savefig(sTmpFileImage)
 
     bytesFrozenImage = None
-    for sDir in [ "/images/", "/images/test/" ]:
+    for sDir in ["/images/", "/images/test/"]:
         sfPath = sDirRoot + sDir + sFileImage
         bytesImage = plt.imread(sfPath)
         if bytesFrozenImage is None:
@@ -49,37 +37,34 @@ def checkFrozenImage(sFileImage):
 
 def test_render_env():
     random.seed(100)
-    oRail = rail_env_generator.generate_random_rail(10,10)
+    oRail = rail_env_generator.generate_random_rail(10, 10)
     type(oRail), len(oRail)
     oEnv = RailEnv(oRail, number_of_agents=2)
     oEnv.reset()
     oRT = rt.RenderTool(oEnv)
-    plt.figure(figsize=(10,10))
+    plt.figure(figsize=(10, 10))
     oRT.renderEnv()
 
     checkFrozenImage("basic-env.png")
 
-    plt.figure(figsize=(10,10))
+    plt.figure(figsize=(10, 10))
     oRT.renderEnv()
-    
+
     lVisits = oRT.getTreeFromRail(
-        oEnv.agents_position[0], 
-        oEnv.agents_direction[0], 
+        oEnv.agents_position[0],
+        oEnv.agents_direction[0],
         nDepth=17, bPlot=True)
 
     checkFrozenImage("env-tree-spatial.png")
-    
-    plt.figure(figsize=(8,8))
+
+    plt.figure(figsize=(8, 8))
     xyTarg = oRT.env.agents_target[0]
     visitDest = oRT.plotTree(lVisits, xyTarg)
 
     checkFrozenImage("env-tree-graph.png")
 
-
-    oFig = plt.figure(figsize=(10,10))
+    plt.figure(figsize=(10, 10))
     oRT.renderEnv()
     oRT.plotPath(visitDest)
 
     checkFrozenImage("env-path.png")
-
-
-- 
GitLab