From 0933d128878968524bf8893b73997db7de90b3c3 Mon Sep 17 00:00:00 2001
From: SP Mohanty <spmohanty91@gmail.com>
Date: Wed, 3 Apr 2019 09:26:54 +0200
Subject: [PATCH] Add RailEnvTransitions as a separate derived class

---
 flatland/core/transitions.py | 76 +++++++++++++++++++-----------------
 1 file changed, 40 insertions(+), 36 deletions(-)

diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py
index dcbde0d8..2b161965 100644
--- a/flatland/core/transitions.py
+++ b/flatland/core/transitions.py
@@ -229,46 +229,50 @@ class GridTransitions(Transitions):
         return cell_transition
 
 
-"""
-Special case of `GridTransitions' over a 2D-grid, with a pre-defined set
-of transitions mimicking the types of real Swiss rail connections.
+class RailEnvTransitions(GridTransitions):
+    """
+    Special case of `GridTransitions' over a 2D-grid, with a pre-defined set
+    of transitions mimicking the types of real Swiss rail connections.
 
------------------------------------------------------------------------------------------------
+    -----------------------------------------------------------------------------------------------
 
-The possible transitions for RailEnv from a cell to its neighboring ones
-are represented over 16 bits.
+    The possible transitions for RailEnv from a cell to its neighboring ones
+    are represented over 16 bits.
 
-Whether a transition is allowed or not depends on which direction an agent
-inside the cell is facing (0=North, 1=East, 2=South, 3=West) and which
-direction the agent wants to move to
-(North, East, South, West, relative to the cell).
-Each transition (orientation, direction) can be allowed (1) or forbidden (0).
+    Whether a transition is allowed or not depends on which direction an agent
+    inside the cell is facing (0=North, 1=East, 2=South, 3=West) and which
+    direction the agent wants to move to
+    (North, East, South, West, relative to the cell).
+    Each transition (orientation, direction) can be allowed (1) or forbidden (0).
 
-The 16 bits are organized in 4 blocks of 4 bits each, the direction that
-the agent is facing.
-E.g., the most-significant 4-bits represent the possible movements (NESW)
-if the agent is facing North, etc...
+    The 16 bits are organized in 4 blocks of 4 bits each, the direction that
+    the agent is facing.
+    E.g., the most-significant 4-bits represent the possible movements (NESW)
+    if the agent is facing North, etc...
 
-agent's direction:          North    East   South   West
-agent's allowed movements:  [nesw]   [nesw] [nesw]  [nesw]
-example:                     0010     0000   1000    0000
+    agent's direction:          North    East   South   West
+    agent's allowed movements:  [nesw]   [nesw] [nesw]  [nesw]
+    example:                     0010     0000   1000    0000
 
-In the example, the agent can move from North to South and viceversa.
-"""
+    In the example, the agent can move from North to South and viceversa.
+    """
 
-"""
-transitions[] is indexed by case type/id, and returns the 4x4-bit [NESW]
-transitions available as a function of the agent's orientation
-(north, east, south, west)
-"""
-RailEnvTransitionsList = [int('0000000000000000', 2),
-                          int('1000000000100000', 2),
-                          int('1001001000100000', 2),
-                          int('1000010000100001', 2),
-                          int('1001011000100001', 2),
-                          int('1100110000110011', 2),
-                          int('0101001000000010', 2),
-                          int('0000000000100000', 2)]
-
-RailEnvTransitions = GridTransitions(transitions=RailEnvTransitionsList,
-                                     allow_diagonal_transitions=False)
+    """
+    transitions[] is indexed by case type/id, and returns the 4x4-bit [NESW]
+    transitions available as a function of the agent's orientation
+    (north, east, south, west)
+    """
+    transition_list = [int('0000000000000000', 2),
+                       int('1000000000100000', 2),
+                       int('1001001000100000', 2),
+                       int('1000010000100001', 2),
+                       int('1001011000100001', 2),
+                       int('1100110000110011', 2),
+                       int('0101001000000010', 2),
+                       int('0000000000100000', 2)]
+
+    def __init__(self):
+        super(RailEnvTransitions, self).__init__(
+            transitions=self.transition_list,
+            allow_diagonal_transitions=False
+        )
-- 
GitLab