From bbf582e64a08ff1bf52b92045e3a51a9406ed156 Mon Sep 17 00:00:00 2001 From: hagrid67 <jdhwatson@gmail.com> Date: Sun, 5 May 2019 15:55:41 +0100 Subject: [PATCH] add remove_deadends flag to set_transition in Grid4Transition. Add has_deadend and remove_deadends() in RailEnvTransitions. (should they be in RailEnv or Grid4...?) --- flatland/core/transitions.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index 4b1874a8..a862c6a8 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -253,8 +253,7 @@ class Grid4Transitions(Transitions): """ return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1 - def set_transition(self, cell_transition, orientation, direction, - new_transition): + def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False): """ Set the transition bit (1 value) that determines whether an agent oriented in direction `orientation' and inside a cell with transitions @@ -271,7 +270,8 @@ class Grid4Transitions(Transitions): Direction of movement whose validity is to be tested. new_transition : int Validity of the requested transition: 0/1 allowed/not allowed. - + remove_deadends -- boolean, default False + remove all deadend transitions. Returns ------- int @@ -285,6 +285,9 @@ class Grid4Transitions(Transitions): else: cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction))) + if remove_deadends: + cell_transition = self.remove_deadends(cell_transition) + return cell_transition def rotate_transition(self, cell_transition, rotation=0): @@ -548,6 +551,10 @@ class RailEnvTransitions(Grid4Transitions): super(RailEnvTransitions, self).__init__( transitions=self.transition_list ) + + # These bits represent all the possible dead ends + self.maskDeadEnds = 0b0010000110000100 + # create this to make validation faster self.transitions_all = [] for index, trans in enumerate(self.transitions): @@ -621,11 +628,11 @@ class RailEnvTransitions(Grid4Transitions): return False def has_deadend(self, cell_transition): - binDeadends = 0b0010000110000100 - if cell_transition & binDeadends > 0: + if cell_transition & self.maskDeadEnds > 0: return True else: return False - # def remove_deadends(self, cell_transition) - \ No newline at end of file + def remove_deadends(self, cell_transition): + cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff + return cell_transition -- GitLab