diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index 4b1874a880a086324f7da8aa48e31529ca6985a3..a862c6a81dbd71a0aa95f2305158f940e964c0ad 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -253,8 +253,7 @@ class Grid4Transitions(Transitions): """ return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1 - def set_transition(self, cell_transition, orientation, direction, - new_transition): + def set_transition(self, cell_transition, orientation, direction, new_transition, remove_deadends=False): """ Set the transition bit (1 value) that determines whether an agent oriented in direction `orientation' and inside a cell with transitions @@ -271,7 +270,8 @@ class Grid4Transitions(Transitions): Direction of movement whose validity is to be tested. new_transition : int Validity of the requested transition: 0/1 allowed/not allowed. - + remove_deadends -- boolean, default False + remove all deadend transitions. Returns ------- int @@ -285,6 +285,9 @@ class Grid4Transitions(Transitions): else: cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction))) + if remove_deadends: + cell_transition = self.remove_deadends(cell_transition) + return cell_transition def rotate_transition(self, cell_transition, rotation=0): @@ -548,6 +551,10 @@ class RailEnvTransitions(Grid4Transitions): super(RailEnvTransitions, self).__init__( transitions=self.transition_list ) + + # These bits represent all the possible dead ends + self.maskDeadEnds = 0b0010000110000100 + # create this to make validation faster self.transitions_all = [] for index, trans in enumerate(self.transitions): @@ -621,11 +628,11 @@ class RailEnvTransitions(Grid4Transitions): return False def has_deadend(self, cell_transition): - binDeadends = 0b0010000110000100 - if cell_transition & binDeadends > 0: + if cell_transition & self.maskDeadEnds > 0: return True else: return False - # def remove_deadends(self, cell_transition) - \ No newline at end of file + def remove_deadends(self, cell_transition): + cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff + return cell_transition