From bbf582e64a08ff1bf52b92045e3a51a9406ed156 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Sun, 5 May 2019 15:55:41 +0100
Subject: [PATCH] add remove_deadends flag to set_transition in
 Grid4Transition. Add has_deadend and remove_deadends() in RailEnvTransitions.
 (should they be in RailEnv or Grid4...?)

---
 flatland/core/transitions.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py
index 4b1874a8..a862c6a8 100644
--- a/flatland/core/transitions.py
+++ b/flatland/core/transitions.py
@@ -253,8 +253,7 @@ class Grid4Transitions(Transitions):
         """
         return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1
 
-    def set_transition(self, cell_transition, orientation, direction,
-                       new_transition):
+    def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False):
         """
         Set the transition bit (1 value) that determines whether an agent
         oriented in direction `orientation' and inside a cell with transitions
@@ -271,7 +270,8 @@ class Grid4Transitions(Transitions):
             Direction of movement whose validity is to be tested.
         new_transition : int
             Validity of the requested transition: 0/1 allowed/not allowed.
-
+        remove_deadends -- boolean, default False
+            remove all deadend transitions.
         Returns
         -------
         int
@@ -285,6 +285,9 @@ class Grid4Transitions(Transitions):
         else:
             cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction)))
 
+        if remove_deadends:
+            cell_transition = self.remove_deadends(cell_transition)
+
         return cell_transition
 
     def rotate_transition(self, cell_transition, rotation=0):
@@ -548,6 +551,10 @@ class RailEnvTransitions(Grid4Transitions):
         super(RailEnvTransitions, self).__init__(
             transitions=self.transition_list
         )
+        
+        # These bits represent all the possible dead ends
+        self.maskDeadEnds = 0b0010000110000100
+
         # create this to make validation faster
         self.transitions_all = []
         for index, trans in enumerate(self.transitions):
@@ -621,11 +628,11 @@ class RailEnvTransitions(Grid4Transitions):
         return False
 
     def has_deadend(self, cell_transition):
-        binDeadends = 0b0010000110000100
-        if cell_transition & binDeadends > 0:
+        if cell_transition & self.maskDeadEnds > 0:
             return True
         else:
             return False
 
-    # def remove_deadends(self, cell_transition)
-    
\ No newline at end of file
+    def remove_deadends(self, cell_transition):
+        cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff
+        return cell_transition
-- 
GitLab